diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index db48af7..cdf9e5f 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -6,6 +6,7 @@ on: pull_request: branches: [ "main" ] types: [ labeled ] + workflow_dispatch: permissions: contents: read @@ -33,4 +34,4 @@ jobs: - name: Test extension build via import run: | - pytest tests/import_test.py -k test_import \ No newline at end of file + pytest tests/import_test.py -k test_extension_built \ No newline at end of file diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 9fb67d0..116129d 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -2,14 +2,18 @@ import sys import torch import numpy as np - -try: - import openequivariance.extlib -except Exception as e: - raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") from pathlib import Path from importlib.metadata import version +import openequivariance.extlib + +from openequivariance.extlib import ( + LINKED_LIBPYTHON, + LINKED_LIBPYTHON_ERROR, + BUILT_EXTENSION, + BUILT_EXTENSION_ERROR, +) + from openequivariance.implementations.e3nn_lite import ( TPProblem, Irrep, @@ -63,9 +67,6 @@ def torch_ext_so_path(): ] ) -LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON -LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR - __all__ = [ "TPProblem", "Irreps", diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 527c4ab..d1d23c0 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -5,21 +5,33 @@ import sysconfig from pathlib import Path -global torch import torch from openequivariance.benchmark.logging_utils import getLogger oeq_root = str(Path(__file__).parent.parent) -build_ext = True -TORCH_COMPILE = True -TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip -torch_module, generic_module = None, None -postprocess_kernel = lambda kernel: kernel # noqa : E731 +BUILT_EXTENSION = False +BUILT_EXTENSION_ERROR = None + +TORCH_COMPILE = False +TORCH_COMPILE_ERROR = None LINKED_LIBPYTHON = False LINKED_LIBPYTHON_ERROR = None + +torch_module, generic_module = None, None +postprocess_kernel = lambda kernel: kernel # noqa : E731 + + +def _compile_torch_cuda_extension(): + return torch.version.cuda + + +def _compile_torch_hip_extension(): + return torch.version.hip + + try: python_lib_dir = sysconfig.get_config_var("LIBDIR") major, minor = sys.version_info.major, sys.version_info.minor @@ -33,114 +45,108 @@ ) LINKED_LIBPYTHON = True - except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" -generic_module = None -if not build_ext: + +if BUILT_EXTENSION: import openequivariance.extlib.generic_module generic_module = openequivariance.extlib.generic_module -elif TORCH_VERSION_CUDA_OR_HIP: - from torch.utils.cpp_extension import library_paths, include_paths - - extra_cflags = ["-O3"] - generic_sources = ["generic_module.cpp"] - torch_sources = ["libtorch_tp_jit.cpp"] - - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) - - if LINKED_LIBPYTHON: - extra_link_args.pop() - extra_link_args.extend( - [ - f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", - f"-L{python_lib_dir}", - f"-l{python_lib_name}", - ], - ) - - if torch.version.cuda: - extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) - - try: +elif _compile_torch_cuda_extension() or _compile_torch_hip_extension(): + try: + from torch.utils.cpp_extension import library_paths, include_paths + + extra_cflags = ["-O3"] + generic_sources = ["generic_module.cpp"] + torch_sources = ["libtorch_tp_jit.cpp"] + + include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + + if LINKED_LIBPYTHON: + extra_link_args.pop() + extra_link_args.extend( + [ + f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", + f"-L{python_lib_dir}", + f"-l{python_lib_name}", + ], + ) + if torch.version.cuda: + extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) torch_libs, cuda_libs = library_paths("cuda") extra_link_args.append("-Wl,-rpath," + torch_libs) extra_link_args.append("-L" + cuda_libs) if os.path.exists(cuda_libs + "/stubs"): extra_link_args.append("-L" + cuda_libs + "/stubs") - except Exception as e: - getLogger().info(str(e)) - - extra_cflags.append("-DCUDA_BACKEND") - elif torch.version.hip: - extra_link_args.extend(["-lhiprtc"]) - torch_libs = library_paths("cuda")[0] - extra_link_args.append("-Wl,-rpath," + torch_libs) - def postprocess(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel - - postprocess_kernel = postprocess - - extra_cflags.append("-DHIP_BACKEND") - - generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] - torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] - include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths( - "cuda" - ) - - torch_compile_exception = None - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + extra_cflags.append("-DCUDA_BACKEND") + elif torch.version.hip: + extra_link_args.extend(["-lhiprtc"]) + torch_libs = library_paths("cuda")[0] + extra_link_args.append("-Wl,-rpath," + torch_libs) - try: - torch_module = torch.utils.cpp_extension.load( - "libtorch_tp_jit", - torch_sources, + def postprocess(kernel): + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel + + postprocess_kernel = postprocess + + extra_cflags.append("-DHIP_BACKEND") + + generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] + torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] + include_dirs = [ + oeq_root + "/extension/" + d for d in include_dirs + ] + include_paths("cuda") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + try: + torch_module = torch.utils.cpp_extension.load( + "libtorch_tp_jit", + torch_sources, + extra_cflags=extra_cflags, + extra_include_paths=include_dirs, + extra_ldflags=extra_link_args, + ) + torch.ops.load_library(torch_module.__file__) + TORCH_COMPILE = True + except Exception as e: + # If compiling torch fails (e.g. low gcc version), we should fall back to the + # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). + TORCH_COMPILE_ERROR = e + + generic_module = torch.utils.cpp_extension.load( + "generic_module", + generic_sources, extra_cflags=extra_cflags, extra_include_paths=include_dirs, extra_ldflags=extra_link_args, ) - torch.ops.load_library(torch_module.__file__) - except Exception as e: - # If compiling torch fails (e.g. low gcc version), we should fall back to the - # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE = False - torch_compile_exception = e - - generic_module = torch.utils.cpp_extension.load( - "generic_module", - generic_sources, - extra_cflags=extra_cflags, - extra_include_paths=include_dirs, - extra_ldflags=extra_link_args, - ) - if "generic_module" not in sys.modules: - sys.modules["generic_module"] = generic_module + if "generic_module" not in sys.modules: + sys.modules["generic_module"] = generic_module - if not TORCH_COMPILE: - warnings.warn( - "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" - + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" - ) -else: - TORCH_COMPILE = False + if not TORCH_COMPILE: + warnings.warn( + "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" + + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}" + ) + BUILT_EXTENSION = True + print("REACHED 3!") + except Exception as e: + BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}" def _raise_import_error_helper(import_target: str): - if not TORCH_VERSION_CUDA_OR_HIP: - raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false" - ) + if not BUILT_EXTENSION: + raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") -if TORCH_VERSION_CUDA_OR_HIP: +if BUILT_EXTENSION: from generic_module import ( JITTPImpl, JITConvImpl, diff --git a/tests/import_test.py b/tests/import_test.py index 3cfbb14..bf26af3 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -7,3 +7,10 @@ def test_import(): assert openequivariance.__version__ is not None assert openequivariance.__version__ != "0.0.0" assert openequivariance.__version__ == version("openequivariance") + + +def test_extension_built(): + from openequivariance import BUILT_EXTENSION, BUILT_EXTENSION_ERROR + + assert BUILT_EXTENSION_ERROR is None + assert BUILT_EXTENSION