From 00d1f89e9b1f12f7df765164ca89d6ecbbe7fe53 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Mon, 22 Dec 2025 21:54:13 +0100 Subject: [PATCH 1/4] Fix is_jax_array for jax>=0.8.2 --- array_api_compat/common/_helpers.py | 6 +++++- tests/test_common.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 37f31ec2..2f8b0e12 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,7 +235,11 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_pydata_sparse_array """ cls = cast(Hashable, type(x)) - return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) + return ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") + or _is_jax_zero_gradient_array(x) + ) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: diff --git a/tests/test_common.py b/tests/test_common.py index 85ed032e..e105fef0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -56,6 +56,15 @@ def test_is_xp_array(library, func): assert is_array_api_obj(x) +def test_is_jax_array_jitted(): + import jax + import jax.numpy as jnp + + x = jnp.asarray([1, 2, 3]) + assert is_jax_array(x) + assert jax.jit(lambda y: is_jax_array(y))(x) + + @pytest.mark.parametrize('library', is_namespace_functions.keys()) @pytest.mark.parametrize('func', is_namespace_functions.values()) def test_is_xp_namespace(library, func): From 4f1e7f0ff1b5c55f6b40efa01eb183943688147f Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Mon, 22 Dec 2025 22:02:05 +0100 Subject: [PATCH 2/4] Skip jax test if not installed --- tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_common.py b/tests/test_common.py index e105fef0..bcff26dc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -57,7 +57,7 @@ def test_is_xp_array(library, func): def test_is_jax_array_jitted(): - import jax + jax = pytest.importorskip("jax") import jax.numpy as jnp x = jnp.asarray([1, 2, 3]) From bcf350fe4c6d16d0a104d928a69a93eeb02d4e9d Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Thu, 25 Dec 2025 13:02:19 +0100 Subject: [PATCH 3/4] Fix and test array_api_obj, is_writable_array, is_lazy_array --- array_api_compat/common/_helpers.py | 3 +++ tests/test_common.py | 9 --------- tests/test_jax.py | 24 +++++++++++++++++++++--- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2f8b0e12..669796d5 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -300,6 +300,7 @@ def _is_array_api_cls(cls: type) -> bool: or _issubclass_fast(cls, "sparse", "SparseArray") # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") ) @@ -938,6 +939,7 @@ def _is_writeable_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "generic") or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") or _issubclass_fast(cls, "sparse", "SparseArray") ): return False @@ -977,6 +979,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return False if ( _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") or _issubclass_fast(cls, "dask.array", "Array") or _issubclass_fast(cls, "ndonnx", "Array") ): diff --git a/tests/test_common.py b/tests/test_common.py index bcff26dc..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -56,15 +56,6 @@ def test_is_xp_array(library, func): assert is_array_api_obj(x) -def test_is_jax_array_jitted(): - jax = pytest.importorskip("jax") - import jax.numpy as jnp - - x = jnp.asarray([1, 2, 3]) - assert is_jax_array(x) - assert jax.jit(lambda y: is_jax_array(y))(x) - - @pytest.mark.parametrize('library', is_namespace_functions.keys()) @pytest.mark.parametrize('func', is_namespace_functions.values()) def test_is_xp_namespace(library, func): diff --git a/tests/test_jax.py b/tests/test_jax.py index 285958d4..8b079409 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,7 +1,14 @@ from numpy.testing import assert_equal import pytest -from array_api_compat import device, to_device +from array_api_compat import ( + device, + to_device, + is_jax_array, + is_lazy_array, + is_array_api_obj, + is_writeable_array, +) try: import jax @@ -13,7 +20,7 @@ @pytest.mark.parametrize( - "func", + "func", [ lambda x: jnp.zeros(1, device=device(x)), lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), @@ -26,7 +33,7 @@ ), ), lambda x: to_device(jnp.zeros(1), device(x)), - ] + ], ) def test_device_jit(func): # Test work around to https://github.com/jax-ml/jax/issues/26000 @@ -36,3 +43,14 @@ def test_device_jit(func): x = jnp.ones(1) assert_equal(func(x), jnp.asarray([0])) assert_equal(jax.jit(func)(x), jnp.asarray([0])) + + +def test_inside_jit(): + jax = pytest.importorskip("jax") + import jax.numpy as jnp + + x = jnp.asarray([1, 2, 3]) + assert jax.jit(is_jax_array)(x) + assert jax.jit(is_array_api_obj)(x) + assert not jax.jit(is_writeable_array)(x) + assert jax.jit(is_lazy_array)(x) From d8befd5341989a5cbc930be4c16f7a5bfae1cd2b Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Thu, 25 Dec 2025 16:33:35 +0100 Subject: [PATCH 4/4] Add comments on jax.core.Tracer detection limitations --- array_api_compat/common/_helpers.py | 12 +++++++++--- tests/test_jax.py | 7 ++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 669796d5..8194a083 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,6 +235,12 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_pydata_sparse_array """ cls = cast(Hashable, type(x)) + # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on, + # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent + # non-array values and a fully correct implementation would need to use isinstance checks. Since + # we use hash-based caching with type names as keys, we cannot use instance checks without + # losing performance here. For more information, see + # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue. return ( _issubclass_fast(cls, "jax", "Array") or _issubclass_fast(cls, "jax.core", "Tracer") @@ -300,7 +306,7 @@ def _is_array_api_cls(cls: type) -> bool: or _issubclass_fast(cls, "sparse", "SparseArray") # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ or _issubclass_fast(cls, "jax", "Array") - or _issubclass_fast(cls, "jax.core", "Tracer") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations ) @@ -939,7 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "generic") or _issubclass_fast(cls, "jax", "Array") - or _issubclass_fast(cls, "jax.core", "Tracer") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "sparse", "SparseArray") ): return False @@ -979,7 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return False if ( _issubclass_fast(cls, "jax", "Array") - or _issubclass_fast(cls, "jax.core", "Tracer") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "dask.array", "Array") or _issubclass_fast(cls, "ndonnx", "Array") ): diff --git a/tests/test_jax.py b/tests/test_jax.py index 8b079409..322d0223 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -46,9 +46,10 @@ def test_device_jit(func): def test_inside_jit(): - jax = pytest.importorskip("jax") - import jax.numpy as jnp - + # Test if jax arrays are handled correctly inside jax.jit. + # Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that + # tracers are handled appropriately. For limitations, see is_jax_array() docstring. + # Reference issue: https://github.com/data-apis/array-api-compat/issues/368 x = jnp.asarray([1, 2, 3]) assert jax.jit(is_jax_array)(x) assert jax.jit(is_array_api_obj)(x)