diff --git a/mkl_random/__init__.py b/mkl_random/__init__.py
index 512027b..01ee956 100644
--- a/mkl_random/__init__.py
+++ b/mkl_random/__init__.py
@@ -42,4 +42,6 @@ test = PytestTester(__name__)
 del PytestTester
 
 
+from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
+
 del _init_helper
diff --git a/mkl_random/src/_patch.pyx b/mkl_random/src/_patch.pyx
new file mode 100644
index 0000000..6c39ff3
--- /dev/null
+++ b/mkl_random/src/_patch.pyx
@@ -0,0 +1,317 @@
+# Copyright (c) 2019, Intel Corporation
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#     * Redistributions of source code must retain the above copyright notice,
+#       this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above copyright
+#       notice, this list of conditions and the following disclaimer in the
+#       documentation and/or other materials provided with the distribution.
+#     * Neither the name of Intel Corporation nor the names of its contributors
+#       may be used to endorse or promote products derived from this software
+#       without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# distutils: language = c
+# cython: language_level=3
+
+"""
+Patch NumPy's `numpy.random` symbols to use mkl_random implementations.
+
+This is attribute-level monkey patching. It can replace legacy APIs like
+`numpy.random.RandomState` and global distribution functions, but it does not
+replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
+compatible replacements.
+
+`numpy.random` is shared by every thread in the process, so the patch
+bookkeeping lives in one module-level state object guarded by a lock.
+"""
+
+from threading import RLock
+from contextlib import ContextDecorator
+
+import numpy as _np
+from . import mklrand as _mr
+
+
+cdef tuple _DEFAULT_NAMES = (
+    # Legacy seeding / state
+    "seed",
+    "get_state",
+    "set_state",
+    "RandomState",
+
+    # Common global sampling helpers
+    "random",
+    "random_sample",
+    "sample",
+    "rand",
+    "randn",
+    "bytes",
+
+    # Integers
+    "randint",
+
+    # Common distributions (only patched if present on both sides)
+    "standard_normal",
+    "normal",
+    "uniform",
+    "exponential",
+    "gamma",
+    "beta",
+    "chisquare",
+    "f",
+    "lognormal",
+    "laplace",
+    "logistic",
+    "multivariate_normal",
+    "poisson",
+    "power",
+    "rayleigh",
+    "triangular",
+    "vonmises",
+    "wald",
+    "weibull",
+    "zipf",
+
+    # Permutations / choices
+    "choice",
+    "permutation",
+    "shuffle",
+)
+
+
+cdef class patch:
+    """
+    Bookkeeping for one in-place patch of a NumPy module's `random` namespace.
+
+    Remembers which module was patched and the objects that were replaced so
+    that `do_unpatch()` can put them back.
+    """
+    cdef bint _is_patched
+    cdef object _numpy_module
+    cdef object _originals  # dict: name -> original object
+    cdef object _patched  # list of names actually patched
+
+    def __cinit__(self):
+        self._is_patched = False
+        self._numpy_module = None
+        self._originals = {}
+        self._patched = []
+
+    def do_patch(self, numpy_module=None, names=None, bint strict=False):
+        """
+        Patch the given numpy module (default: imported numpy) in-place.
+
+        Parameters
+        ----------
+        numpy_module : module, optional
+            The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
+        names : str or iterable[str], optional
+            Attributes under `numpy_module.random` to patch. Defaults to
+            _DEFAULT_NAMES. A single string is treated as one name, and
+            repeated names are patched once.
+        strict : bool
+            If True, raise if any requested symbol cannot be patched.
+
+        Raises
+        ------
+        TypeError
+            If `numpy_module` has no `random` attribute.
+        RuntimeError
+            If a different numpy module is already patched.
+        AttributeError
+            If `strict` is True and a requested name is missing on either side.
+
+        Notes
+        -----
+        Calling this again for the module that is already patched is a no-op,
+        even if `names` differs. On any failure every attribute replaced so far
+        is put back before the exception propagates.
+        """
+        if numpy_module is None:
+            numpy_module = _np
+        if names is None:
+            names = _DEFAULT_NAMES
+        elif isinstance(names, str):
+            # A bare string would otherwise be iterated character by character.
+            names = (names,)
+
+        if not hasattr(numpy_module, "random"):
+            raise TypeError("Expected a numpy-like module with a `.random` attribute.")
+
+        # If already patched, only allow idempotent re-entry for the same numpy module.
+        if self._is_patched:
+            if self._numpy_module is numpy_module:
+                return
+            raise RuntimeError("Already patched a different numpy module; call restore() first.")
+
+        np_random = numpy_module.random
+
+        originals = {}
+        patched = []
+        missing = []
+
+        try:
+            for name in names:
+                if name in originals:
+                    # Already replaced in this call: reading the attribute
+                    # again would save the replacement as the "original".
+                    continue
+                if not hasattr(np_random, name) or not hasattr(_mr, name):
+                    missing.append(name)
+                    continue
+                originals[name] = getattr(np_random, name)
+                setattr(np_random, name, getattr(_mr, name))
+                patched.append(name)
+
+            if strict and missing:
+                raise AttributeError(
+                    "Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
+                    + ", ".join([str(x) for x in missing])
+                )
+        except BaseException:
+            # Never leave numpy.random half patched with no record of it.
+            for n, v in originals.items():
+                setattr(np_random, n, v)
+            raise
+
+        self._numpy_module = numpy_module
+        self._originals = originals
+        self._patched = patched
+        self._is_patched = True
+
+    def do_unpatch(self):
+        """
+        Restore the previously patched numpy module.
+        """
+        if not self._is_patched:
+            return
+        numpy_module = self._numpy_module
+        np_random = numpy_module.random
+        for n, v in self._originals.items():
+            setattr(np_random, n, v)
+
+        self._numpy_module = None
+        self._originals = {}
+        self._patched = []
+        self._is_patched = False
+
+    def is_patched(self):
+        """
+        Returns whether a patch is currently applied.
+        """
+        return self._is_patched
+
+    def patched_names(self):
+        """
+        Returns list of names that were actually patched.
+        """
+        return list(self._patched)
+
+
+# numpy.random is process-wide, so the bookkeeping must be too: with per-thread
+# state a second thread would record the already-replaced objects as
+# "originals" and its restore() would reinstall the mkl_random functions.
+_state = patch()
+_lock = RLock()
+
+
+def monkey_patch(numpy_module=None, names=None, strict=False):
+    """
+    Enables using mkl_random in the given NumPy module by patching `numpy.random`.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import mkl_random
+    >>> mkl_random.is_patched()
+    False
+    >>> mkl_random.monkey_patch(np)
+    >>> mkl_random.is_patched()
+    True
+    >>> mkl_random.restore()
+    >>> mkl_random.is_patched()
+    False
+    """
+    with _lock:
+        _state.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict))
+
+
+def use_in_numpy(numpy_module=None, names=None, strict=False):
+    """
+    Backward-compatible alias for monkey_patch().
+    """
+    monkey_patch(numpy_module=numpy_module, names=names, strict=strict)
+
+
+def restore():
+    """
+    Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
+    """
+    with _lock:
+        _state.do_unpatch()
+
+
+def is_patched():
+    """
+    Returns whether NumPy has been patched with mkl_random.
+    """
+    with _lock:
+        return bool(_state.is_patched())
+
+
+def patched_names():
+    """
+    Returns the names actually patched in `numpy.random`.
+    """
+    with _lock:
+        return _state.patched_names()
+
+
+class mkl_random(ContextDecorator):
+    """
+    Context manager and decorator to temporarily patch NumPy's `numpy.random`.
+
+    If NumPy is already patched on entry (by `monkey_patch()` or an enclosing
+    `mkl_random` block), the existing patch is left in place on exit; only the
+    entry that applied the patch restores the original symbols.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import mkl_random
+    >>> with mkl_random.mkl_random():
+    ...     x = np.random.normal(size=10)
+    """
+    def __init__(self, numpy_module=None, names=None, strict=False):
+        self._numpy_module = numpy_module
+        self._names = names
+        self._strict = strict
+        # One flag per active entry (the same instance is re-entered when used
+        # as a decorator on a recursive function): True if that entry patched.
+        self._applied = []
+
+    def __enter__(self):
+        with _lock:
+            applied = not is_patched()
+            monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
+            # Recorded only after patching succeeded; a failed __enter__ has no __exit__.
+            self._applied.append(applied)
+        return self
+
+    def __exit__(self, *exc):
+        if self._applied.pop():
+            restore()
+        return False
diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py
new file mode 100644
index 0000000..3dabea1
--- /dev/null
+++ b/mkl_random/tests/test_patch.py
@@ -0,0 +1,136 @@
+import numpy as np
+import mkl_random
+import pytest
+
+def test_is_patched():
+    """
+    Test that is_patched() returns correct status.
+    """
+    assert not mkl_random.is_patched()
+    mkl_random.monkey_patch(np)
+    assert mkl_random.is_patched()
+    mkl_random.restore()
+    assert not mkl_random.is_patched()
+
+def test_monkey_patch_and_restore():
+    """
+    Test that monkey_patch replaces and restore brings back original functions.
+    """
+    # Store original functions
+    orig_normal = np.random.normal
+    orig_randint = np.random.randint
+    orig_RandomState = np.random.RandomState
+
+    try:
+        mkl_random.monkey_patch(np)
+
+        # Check that functions are now different objects
+        assert np.random.normal is not orig_normal
+        assert np.random.randint is not orig_randint
+        assert np.random.RandomState is not orig_RandomState
+
+        # Check that they are from mkl_random
+        assert np.random.normal is mkl_random.mklrand.normal
+        assert np.random.RandomState is mkl_random.mklrand.RandomState
+
+    finally:
+        mkl_random.restore()
+
+    # Check that original functions are restored
+    assert mkl_random.is_patched() is False
+    assert np.random.normal is orig_normal
+    assert np.random.randint is orig_randint
+    assert np.random.RandomState is orig_RandomState
+
+def test_context_manager():
+    """
+    Test that the context manager patches and automatically restores.
+    """
+    orig_uniform = np.random.uniform
+    assert not mkl_random.is_patched()
+
+    with mkl_random.mkl_random(np):
+        assert mkl_random.is_patched() is True
+        assert np.random.uniform is not orig_uniform
+        # Smoke test inside context
+        arr = np.random.uniform(size=10)
+        assert arr.shape == (10,)
+
+    assert not mkl_random.is_patched()
+    assert np.random.uniform is orig_uniform
+
+def test_patched_functions_callable():
+    """
+    Smoke test to ensure some patched functions can be called without error.
+    """
+    mkl_random.monkey_patch(np)
+    try:
+        # These calls should now be routed to mkl_random's implementations
+        x = np.random.standard_normal(size=100)
+        assert x.shape == (100,)
+
+        y = np.random.randint(0, 100, size=50)
+        assert y.shape == (50,)
+        assert np.all(y >= 0) and np.all(y < 100)
+
+        st = np.random.RandomState(12345)
+        z = st.rand(10)
+        assert z.shape == (10,)
+
+    finally:
+        mkl_random.restore()
+
+def test_patched_names():
+    """
+    Test that patched_names() returns a list of patched symbols.
+    """
+    try:
+        mkl_random.monkey_patch(np)
+        names = mkl_random.patched_names()
+        assert isinstance(names, list)
+        assert len(names) > 0
+        assert "normal" in names
+        assert "RandomState" in names
+    finally:
+        mkl_random.restore()
+
+def test_context_manager_nested():
+    """
+    Test that leaving an inner context does not undo the outer patch.
+    """
+    orig_uniform = np.random.uniform
+
+    with mkl_random.mkl_random(np):
+        with mkl_random.mkl_random(np):
+            assert mkl_random.is_patched()
+        assert mkl_random.is_patched()
+        assert np.random.uniform is not orig_uniform
+
+    assert not mkl_random.is_patched()
+    assert np.random.uniform is orig_uniform
+
+def test_context_manager_keeps_existing_patch():
+    """
+    Test that the context manager leaves a pre-existing patch in place.
+    """
+    mkl_random.monkey_patch(np)
+    try:
+        with mkl_random.mkl_random(np):
+            assert mkl_random.is_patched()
+        assert mkl_random.is_patched()
+    finally:
+        mkl_random.restore()
+    assert not mkl_random.is_patched()
+
+def test_duplicate_names_are_restored():
+    """
+    Test that a name requested twice is still restored to the original.
+    """
+    orig_normal = np.random.normal
+    mkl_random.monkey_patch(np, names=["normal", "normal"])
+    try:
+        assert mkl_random.patched_names() == ["normal"]
+        assert np.random.normal is mkl_random.mklrand.normal
+    finally:
+        mkl_random.restore()
+    assert np.random.normal is orig_normal
diff --git a/setup.py b/setup.py
index c47ebfb..70f83ea 100644
--- a/setup.py
+++ b/setup.py
@@ -81,6 +81,14 @@ def extensions():
             extra_compile_args = eca,
             define_macros=defs + [("NDEBUG", None)],
             language="c++"
+        ),
+
+        Extension(
+            "mkl_random._patch",
+            sources=[join("mkl_random", "src", "_patch.pyx")],
+            include_dirs=[np.get_include()],
+            define_macros=defs + [("NDEBUG", None)],
+            language="c",
         )
     ]
 