From dbc26dfa388218307fc3cae6f74d326e9b101083 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 1 Nov 2025 17:07:13 -0700 Subject: [PATCH 01/11] stronger criteria to attempt compilation --- openequivariance/extlib/__init__.py | 43 +++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 527c4ab..9e1b937 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -3,6 +3,8 @@ import os import warnings import sysconfig +import shutil +from functools import lru_cache from pathlib import Path global torch @@ -14,7 +16,30 @@ build_ext = True TORCH_COMPILE = True -TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip + + +@lru_cache(maxsize=1) +def _compile_torch_cuda_extension(): + return torch.version.cuda and _nvcc_present() + + +def _nvcc_present(): + return shutil.which("nvcc") is not None + + +@lru_cache(maxsize=1) +def _compile_torch_hip_extension(): + return torch.version.hip and _hipcc_present() + + +def _hipcc_present(): + return shutil.which("hipcc") is not None + + +COMPILE_TORCH_CUDA_EXTENSION = _compile_torch_cuda_extension() +COMPILE_TORCH_HIP_EXTENSION = _compile_torch_hip_extension() + + torch_module, generic_module = None, None postprocess_kernel = lambda kernel: kernel # noqa : E731 @@ -42,7 +67,7 @@ import openequivariance.extlib.generic_module generic_module = openequivariance.extlib.generic_module -elif TORCH_VERSION_CUDA_OR_HIP: +elif COMPILE_TORCH_CUDA_EXTENSION or COMPILE_TORCH_HIP_EXTENSION: from torch.utils.cpp_extension import library_paths, include_paths extra_cflags = ["-O3"] @@ -134,13 +159,21 @@ def postprocess(kernel): def _raise_import_error_helper(import_target: str): - if not TORCH_VERSION_CUDA_OR_HIP: + if torch.version.cuda: + raise ImportError( + f"Could not import {import_target}: OpenEquivariance's torch extension was not built because nvcc was not found" + ) + elif torch.version.hip: + raise ImportError( + f"Could not import {import_target}: OpenEquivariance's torch extension was not built because hipcc was not found" + ) + else: raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false" + f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip was False" ) -if TORCH_VERSION_CUDA_OR_HIP: +if COMPILE_TORCH_CUDA_EXTENSION or COMPILE_TORCH_HIP_EXTENSION: from generic_module import ( JITTPImpl, JITConvImpl, From 2fb12bbeef6577a8690b5071f9637cdce84e1ca7 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 7 Nov 2025 12:50:58 -0800 Subject: [PATCH 02/11] integrating PR feedback --- openequivariance/__init__.py | 17 +-- openequivariance/extlib/__init__.py | 198 +++++++++++++--------------- tests/import_test.py | 7 + 3 files changed, 106 insertions(+), 116 deletions(-) diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 9fb67d0..116129d 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -2,14 +2,18 @@ import sys import torch import numpy as np - -try: - import openequivariance.extlib -except Exception as e: - raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") from pathlib import Path from importlib.metadata import version +import openequivariance.extlib + +from openequivariance.extlib import ( + LINKED_LIBPYTHON, + LINKED_LIBPYTHON_ERROR, + BUILT_EXTENSION, + BUILT_EXTENSION_ERROR, +) + from openequivariance.implementations.e3nn_lite import ( TPProblem, Irrep, @@ -63,9 +67,6 @@ def torch_ext_so_path(): ] ) -LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON -LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR - __all__ = [ "TPProblem", "Irreps", diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 9e1b937..188b143 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -3,37 +3,28 @@ import os import warnings import sysconfig -import shutil from functools import lru_cache from pathlib import Path -global torch import torch from openequivariance.benchmark.logging_utils import getLogger oeq_root = str(Path(__file__).parent.parent) -build_ext = True TORCH_COMPILE = True +BUILT_EXTENSION = False +BUILT_EXTENSION_ERROR = None @lru_cache(maxsize=1) def _compile_torch_cuda_extension(): - return torch.version.cuda and _nvcc_present() - - -def _nvcc_present(): - return shutil.which("nvcc") is not None + return torch.version.cuda and ("CUDA_HOME" in os.environ) @lru_cache(maxsize=1) def _compile_torch_hip_extension(): - return torch.version.hip and _hipcc_present() - - -def _hipcc_present(): - return shutil.which("hipcc") is not None + return torch.version.hip and ("HIP_HOME" in os.environ) COMPILE_TORCH_CUDA_EXTENSION = _compile_torch_cuda_extension() @@ -63,117 +54,108 @@ def _hipcc_present(): LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" generic_module = None -if not build_ext: +if BUILT_EXTENSION: import openequivariance.extlib.generic_module generic_module = openequivariance.extlib.generic_module elif COMPILE_TORCH_CUDA_EXTENSION or COMPILE_TORCH_HIP_EXTENSION: - from torch.utils.cpp_extension import library_paths, include_paths - - extra_cflags = ["-O3"] - generic_sources = ["generic_module.cpp"] - torch_sources = ["libtorch_tp_jit.cpp"] - - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) - - if LINKED_LIBPYTHON: - extra_link_args.pop() - extra_link_args.extend( - [ - f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", - f"-L{python_lib_dir}", - f"-l{python_lib_name}", - ], - ) - - if torch.version.cuda: - extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) - - try: - torch_libs, cuda_libs = library_paths("cuda") + try: + from torch.utils.cpp_extension import library_paths, include_paths + + extra_cflags = ["-O3"] + generic_sources = ["generic_module.cpp"] + torch_sources = ["libtorch_tp_jit.cpp"] + + include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + + if LINKED_LIBPYTHON: + extra_link_args.pop() + extra_link_args.extend( + [ + f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", + f"-L{python_lib_dir}", + f"-l{python_lib_name}", + ], + ) + if torch.version.cuda: + extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) + + try: + torch_libs, cuda_libs = library_paths("cuda") + extra_link_args.append("-Wl,-rpath," + torch_libs) + extra_link_args.append("-L" + cuda_libs) + if os.path.exists(cuda_libs + "/stubs"): + extra_link_args.append("-L" + cuda_libs + "/stubs") + except Exception as e: + getLogger().info(str(e)) + + extra_cflags.append("-DCUDA_BACKEND") + elif torch.version.hip: + extra_link_args.extend(["-lhiprtc"]) + torch_libs = library_paths("cuda")[0] extra_link_args.append("-Wl,-rpath," + torch_libs) - extra_link_args.append("-L" + cuda_libs) - if os.path.exists(cuda_libs + "/stubs"): - extra_link_args.append("-L" + cuda_libs + "/stubs") - except Exception as e: - getLogger().info(str(e)) - - extra_cflags.append("-DCUDA_BACKEND") - elif torch.version.hip: - extra_link_args.extend(["-lhiprtc"]) - torch_libs = library_paths("cuda")[0] - extra_link_args.append("-Wl,-rpath," + torch_libs) - - def postprocess(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel - - postprocess_kernel = postprocess - - extra_cflags.append("-DHIP_BACKEND") - - generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] - torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] - include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths( - "cuda" - ) - torch_compile_exception = None - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - try: - torch_module = torch.utils.cpp_extension.load( - "libtorch_tp_jit", - torch_sources, + def postprocess(kernel): + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel + + postprocess_kernel = postprocess + + extra_cflags.append("-DHIP_BACKEND") + + generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] + torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] + include_dirs = [ + oeq_root + "/extension/" + d for d in include_dirs + ] + include_paths("cuda") + + torch_compile_exception = None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + try: + torch_module = torch.utils.cpp_extension.load( + "libtorch_tp_jit", + torch_sources, + extra_cflags=extra_cflags, + extra_include_paths=include_dirs, + extra_ldflags=extra_link_args, + ) + torch.ops.load_library(torch_module.__file__) + except Exception as e: + # If compiling torch fails (e.g. low gcc version), we should fall back to the + # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). + TORCH_COMPILE = False + torch_compile_exception = e + + generic_module = torch.utils.cpp_extension.load( + "generic_module", + generic_sources, extra_cflags=extra_cflags, extra_include_paths=include_dirs, extra_ldflags=extra_link_args, ) - torch.ops.load_library(torch_module.__file__) - except Exception as e: - # If compiling torch fails (e.g. low gcc version), we should fall back to the - # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE = False - torch_compile_exception = e - - generic_module = torch.utils.cpp_extension.load( - "generic_module", - generic_sources, - extra_cflags=extra_cflags, - extra_include_paths=include_dirs, - extra_ldflags=extra_link_args, - ) - if "generic_module" not in sys.modules: - sys.modules["generic_module"] = generic_module + if "generic_module" not in sys.modules: + sys.modules["generic_module"] = generic_module - if not TORCH_COMPILE: - warnings.warn( - "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" - + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" - ) -else: - TORCH_COMPILE = False + if not TORCH_COMPILE: + warnings.warn( + "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" + + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" + ) + BUILT_EXTENSION = True + except Exception as e: + BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}" def _raise_import_error_helper(import_target: str): - if torch.version.cuda: - raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because nvcc was not found" - ) - elif torch.version.hip: - raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because hipcc was not found" - ) - else: - raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip was False" - ) + if not BUILT_EXTENSION: + raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") -if COMPILE_TORCH_CUDA_EXTENSION or COMPILE_TORCH_HIP_EXTENSION: +if BUILT_EXTENSION: from generic_module import ( JITTPImpl, JITConvImpl, diff --git a/tests/import_test.py b/tests/import_test.py index 3cfbb14..bf26af3 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -7,3 +7,10 @@ def test_import(): assert openequivariance.__version__ is not None assert openequivariance.__version__ != "0.0.0" assert openequivariance.__version__ == version("openequivariance") + + +def test_extension_built(): + from openequivariance import BUILT_EXTENSION, BUILT_EXTENSION_ERROR + + assert BUILT_EXTENSION_ERROR is None + assert BUILT_EXTENSION From 351e3b33a7878ea1d636d87eed75c4bb83aed566 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:03:03 -0800 Subject: [PATCH 03/11] make TORCH_COMPILE False by default, remove duplicate setting of generic_module. remove constants for _compile_torch_extension() as it's referenced only once. torch_compile_exception -> TORCH_COMPILE_ERROR for consistency. --- openequivariance/extlib/__init__.py | 33 +++++++++++++---------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 188b143..6b28756 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -12,10 +12,18 @@ oeq_root = str(Path(__file__).parent.parent) -TORCH_COMPILE = True BUILT_EXTENSION = False BUILT_EXTENSION_ERROR = None +TORCH_COMPILE = False +TORCH_COMPILE_ERROR = None + +LINKED_LIBPYTHON = False +LINKED_LIBPYTHON_ERROR = None + +torch_module, generic_module = None, None +postprocess_kernel = lambda kernel: kernel # noqa : E731 + @lru_cache(maxsize=1) def _compile_torch_cuda_extension(): @@ -27,15 +35,6 @@ def _compile_torch_hip_extension(): return torch.version.hip and ("HIP_HOME" in os.environ) -COMPILE_TORCH_CUDA_EXTENSION = _compile_torch_cuda_extension() -COMPILE_TORCH_HIP_EXTENSION = _compile_torch_hip_extension() - - -torch_module, generic_module = None, None -postprocess_kernel = lambda kernel: kernel # noqa : E731 - -LINKED_LIBPYTHON = False -LINKED_LIBPYTHON_ERROR = None try: python_lib_dir = sysconfig.get_config_var("LIBDIR") major, minor = sys.version_info.major, sys.version_info.minor @@ -49,16 +48,15 @@ def _compile_torch_hip_extension(): ) LINKED_LIBPYTHON = True - except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" -generic_module = None + if BUILT_EXTENSION: import openequivariance.extlib.generic_module - + generic_module = openequivariance.extlib.generic_module -elif COMPILE_TORCH_CUDA_EXTENSION or COMPILE_TORCH_HIP_EXTENSION: +elif _compile_torch_cuda_extension() or _compile_torch_hip_extension(): try: from torch.utils.cpp_extension import library_paths, include_paths @@ -111,7 +109,6 @@ def postprocess(kernel): oeq_root + "/extension/" + d for d in include_dirs ] + include_paths("cuda") - torch_compile_exception = None with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -124,11 +121,11 @@ def postprocess(kernel): extra_ldflags=extra_link_args, ) torch.ops.load_library(torch_module.__file__) + TORCH_COMPILE = True except Exception as e: # If compiling torch fails (e.g. low gcc version), we should fall back to the # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE = False - torch_compile_exception = e + TORCH_COMPILE_ERROR = e generic_module = torch.utils.cpp_extension.load( "generic_module", @@ -143,7 +140,7 @@ def postprocess(kernel): if not TORCH_COMPILE: warnings.warn( "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" - + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" + + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}" ) BUILT_EXTENSION = True except Exception as e: From 4ff819f4bba3f12d9b945c78aa51371736049582 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 7 Nov 2025 14:16:39 -0800 Subject: [PATCH 04/11] format --- openequivariance/extlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 6b28756..e3fe252 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -54,7 +54,7 @@ def _compile_torch_hip_extension(): if BUILT_EXTENSION: import openequivariance.extlib.generic_module - + generic_module = openequivariance.extlib.generic_module elif _compile_torch_cuda_extension() or _compile_torch_hip_extension(): try: From 04c50789a059ef29103ccde155cbf15b9ce93c10 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 8 Nov 2025 12:21:34 -0800 Subject: [PATCH 05/11] remove lru cache --- openequivariance/extlib/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index e3fe252..9b2588e 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -3,7 +3,6 @@ import os import warnings import sysconfig -from functools import lru_cache from pathlib import Path import torch @@ -25,12 +24,10 @@ postprocess_kernel = lambda kernel: kernel # noqa : E731 -@lru_cache(maxsize=1) def _compile_torch_cuda_extension(): return torch.version.cuda and ("CUDA_HOME" in os.environ) -@lru_cache(maxsize=1) def _compile_torch_hip_extension(): return torch.version.hip and ("HIP_HOME" in os.environ) From 4e1a595f7eb71565351338d913fa7260d203c5f3 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 8 Nov 2025 12:21:53 -0800 Subject: [PATCH 06/11] bad commit --- openequivariance/extension/util/backend_cuda.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/extension/util/backend_cuda.hpp index 364186f..6db46d2 100644 --- a/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/extension/util/backend_cuda.hpp @@ -91,7 +91,8 @@ class GPUTimer { int* ptr = (int*) (CUDA_Allocator::gpu_alloc(element_count * sizeof(int))); CUDA_ERRCHK(cudaMemset(ptr, 42, element_count * sizeof(int))) CUDA_Allocator::gpu_free(ptr); - cudaDeviceSynchronize(); + cudaDeviceSy + nchronize(); } ~GPUTimer() { From 0f0bb078a9d478783640438c3cea5e79caaac0e0 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 8 Nov 2025 13:44:13 -0800 Subject: [PATCH 07/11] test extension built --- .github/workflows/verify_extension_build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index db48af7..69ddbae 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -33,4 +33,4 @@ jobs: - name: Test extension build via import run: | - pytest tests/import_test.py -k test_import \ No newline at end of file + pytest tests/import_test.py -k test_extension_built \ No newline at end of file From 91c34d7b319704b14b49f6a6368d87a18f988542 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 8 Nov 2025 13:44:27 -0800 Subject: [PATCH 08/11] revert intentional mistake --- openequivariance/extension/util/backend_cuda.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/extension/util/backend_cuda.hpp index 6db46d2..364186f 100644 --- a/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/extension/util/backend_cuda.hpp @@ -91,8 +91,7 @@ class GPUTimer { int* ptr = (int*) (CUDA_Allocator::gpu_alloc(element_count * sizeof(int))); CUDA_ERRCHK(cudaMemset(ptr, 42, element_count * sizeof(int))) CUDA_Allocator::gpu_free(ptr); - cudaDeviceSy - nchronize(); + cudaDeviceSynchronize(); } ~GPUTimer() { From e9ab609c963b530f1ed683903a46d00f6381fed4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 9 Nov 2025 18:34:12 -0800 Subject: [PATCH 09/11] Added workflow dispatch trigger. --- .github/workflows/verify_extension_build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 69ddbae..cdf9e5f 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -6,6 +6,7 @@ on: pull_request: branches: [ "main" ] types: [ labeled ] + workflow_dispatch: permissions: contents: read From d2d7acfc27f83183a31a39f10cae60bf3106474e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 9 Nov 2025 18:44:07 -0800 Subject: [PATCH 10/11] Removed a try-catch. --- openequivariance/extlib/__init__.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 9b2588e..2fecfb8 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -74,15 +74,11 @@ def _compile_torch_hip_extension(): ) if torch.version.cuda: extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) - - try: - torch_libs, cuda_libs = library_paths("cuda") - extra_link_args.append("-Wl,-rpath," + torch_libs) - extra_link_args.append("-L" + cuda_libs) - if os.path.exists(cuda_libs + "/stubs"): - extra_link_args.append("-L" + cuda_libs + "/stubs") - except Exception as e: - getLogger().info(str(e)) + torch_libs, cuda_libs = library_paths("cuda") + extra_link_args.append("-Wl,-rpath," + torch_libs) + extra_link_args.append("-L" + cuda_libs) + if os.path.exists(cuda_libs + "/stubs"): + extra_link_args.append("-L" + cuda_libs + "/stubs") extra_cflags.append("-DCUDA_BACKEND") elif torch.version.hip: From 89f3e3cb0e83ddab3bf955fac2cdce0ebc9971b9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 9 Nov 2025 18:54:03 -0800 Subject: [PATCH 11/11] Fixed another error. --- openequivariance/extlib/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 2fecfb8..d1d23c0 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -25,11 +25,11 @@ def _compile_torch_cuda_extension(): - return torch.version.cuda and ("CUDA_HOME" in os.environ) + return torch.version.cuda def _compile_torch_hip_extension(): - return torch.version.hip and ("HIP_HOME" in os.environ) + return torch.version.hip try: @@ -136,6 +136,7 @@ def postprocess(kernel): + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}" ) BUILT_EXTENSION = True + print("REACHED 3!") except Exception as e: BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"