Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/verify_extension_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
pull_request:
branches: [ "main" ]
types: [ labeled ]
workflow_dispatch:

permissions:
contents: read
Expand Down Expand Up @@ -33,4 +34,4 @@ jobs:

- name: Test extension build via import
run: |
pytest tests/import_test.py -k test_import
pytest tests/import_test.py -k test_extension_built
17 changes: 9 additions & 8 deletions openequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import sys
import torch
import numpy as np

try:
import openequivariance.extlib
except Exception as e:
raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}")
from pathlib import Path
from importlib.metadata import version

import openequivariance.extlib

from openequivariance.extlib import (
LINKED_LIBPYTHON,
LINKED_LIBPYTHON_ERROR,
BUILT_EXTENSION,
BUILT_EXTENSION_ERROR,
)

from openequivariance.implementations.e3nn_lite import (
TPProblem,
Irrep,
Expand Down Expand Up @@ -63,9 +67,6 @@ def torch_ext_so_path():
]
)

LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON
LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR

__all__ = [
"TPProblem",
"Irreps",
Expand Down
188 changes: 97 additions & 91 deletions openequivariance/extlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,33 @@
import sysconfig
from pathlib import Path

global torch
import torch

from openequivariance.benchmark.logging_utils import getLogger

oeq_root = str(Path(__file__).parent.parent)

build_ext = True
TORCH_COMPILE = True
TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel # noqa : E731
BUILT_EXTENSION = False
BUILT_EXTENSION_ERROR = None

TORCH_COMPILE = False
TORCH_COMPILE_ERROR = None

LINKED_LIBPYTHON = False
LINKED_LIBPYTHON_ERROR = None

torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel # noqa : E731


def _compile_torch_cuda_extension():
return torch.version.cuda


def _compile_torch_hip_extension():
return torch.version.hip


try:
python_lib_dir = sysconfig.get_config_var("LIBDIR")
major, minor = sys.version_info.major, sys.version_info.minor
Expand All @@ -33,114 +45,108 @@
)

LINKED_LIBPYTHON = True

except Exception as e:
LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}"

generic_module = None
if not build_ext:

if BUILT_EXTENSION:
import openequivariance.extlib.generic_module

generic_module = openequivariance.extlib.generic_module
elif TORCH_VERSION_CUDA_OR_HIP:
from torch.utils.cpp_extension import library_paths, include_paths

extra_cflags = ["-O3"]
generic_sources = ["generic_module.cpp"]
torch_sources = ["libtorch_tp_jit.cpp"]

include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])

if LINKED_LIBPYTHON:
extra_link_args.pop()
extra_link_args.extend(
[
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
f"-L{python_lib_dir}",
f"-l{python_lib_name}",
],
)

if torch.version.cuda:
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])

try:
elif _compile_torch_cuda_extension() or _compile_torch_hip_extension():
try:
from torch.utils.cpp_extension import library_paths, include_paths

extra_cflags = ["-O3"]
generic_sources = ["generic_module.cpp"]
torch_sources = ["libtorch_tp_jit.cpp"]

include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])

if LINKED_LIBPYTHON:
extra_link_args.pop()
extra_link_args.extend(
[
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
f"-L{python_lib_dir}",
f"-l{python_lib_name}",
],
)
if torch.version.cuda:
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
torch_libs, cuda_libs = library_paths("cuda")
extra_link_args.append("-Wl,-rpath," + torch_libs)
extra_link_args.append("-L" + cuda_libs)
if os.path.exists(cuda_libs + "/stubs"):
extra_link_args.append("-L" + cuda_libs + "/stubs")
except Exception as e:
getLogger().info(str(e))

extra_cflags.append("-DCUDA_BACKEND")
elif torch.version.hip:
extra_link_args.extend(["-lhiprtc"])
torch_libs = library_paths("cuda")[0]
extra_link_args.append("-Wl,-rpath," + torch_libs)

def postprocess(kernel):
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
return kernel

postprocess_kernel = postprocess

extra_cflags.append("-DHIP_BACKEND")

generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths(
"cuda"
)

torch_compile_exception = None
with warnings.catch_warnings():
warnings.simplefilter("ignore")
extra_cflags.append("-DCUDA_BACKEND")
elif torch.version.hip:
extra_link_args.extend(["-lhiprtc"])
torch_libs = library_paths("cuda")[0]
extra_link_args.append("-Wl,-rpath," + torch_libs)

try:
torch_module = torch.utils.cpp_extension.load(
"libtorch_tp_jit",
torch_sources,
def postprocess(kernel):
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
return kernel

postprocess_kernel = postprocess

extra_cflags.append("-DHIP_BACKEND")

generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
include_dirs = [
oeq_root + "/extension/" + d for d in include_dirs
] + include_paths("cuda")

with warnings.catch_warnings():
warnings.simplefilter("ignore")

try:
torch_module = torch.utils.cpp_extension.load(
"libtorch_tp_jit",
torch_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
torch.ops.load_library(torch_module.__file__)
TORCH_COMPILE = True
except Exception as e:
# If compiling torch fails (e.g. low gcc version), we should fall back to the
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
TORCH_COMPILE_ERROR = e

generic_module = torch.utils.cpp_extension.load(
"generic_module",
generic_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
torch.ops.load_library(torch_module.__file__)
except Exception as e:
# If compiling torch fails (e.g. low gcc version), we should fall back to the
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
TORCH_COMPILE = False
torch_compile_exception = e

generic_module = torch.utils.cpp_extension.load(
"generic_module",
generic_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
if "generic_module" not in sys.modules:
sys.modules["generic_module"] = generic_module
if "generic_module" not in sys.modules:
sys.modules["generic_module"] = generic_module

if not TORCH_COMPILE:
warnings.warn(
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+ f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
)
else:
TORCH_COMPILE = False
if not TORCH_COMPILE:
warnings.warn(
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+ f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}"
)
BUILT_EXTENSION = True
print("REACHED 3!")
except Exception as e:
BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"


def _raise_import_error_helper(import_target: str):
if not TORCH_VERSION_CUDA_OR_HIP:
raise ImportError(
f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
)
if not BUILT_EXTENSION:
raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")


if TORCH_VERSION_CUDA_OR_HIP:
if BUILT_EXTENSION:
from generic_module import (
JITTPImpl,
JITConvImpl,
Expand Down
7 changes: 7 additions & 0 deletions tests/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,10 @@ def test_import():
assert openequivariance.__version__ is not None
assert openequivariance.__version__ != "0.0.0"
assert openequivariance.__version__ == version("openequivariance")


def test_extension_built():
from openequivariance import BUILT_EXTENSION, BUILT_EXTENSION_ERROR

assert BUILT_EXTENSION_ERROR is None
assert BUILT_EXTENSION
Loading