diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 37f31ec2..8194a083 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_pydata_sparse_array """ cls = cast(Hashable, type(x)) - return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) + # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on, + # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent + # non-array values and a fully correct implementation would need to use isinstance checks. Since + # we use hash-based caching with type names as keys, we cannot use instance checks without + # losing performance here. For more information, see + # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue. + return ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") + or _is_jax_zero_gradient_array(x) + ) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool: or _issubclass_fast(cls, "sparse", "SparseArray") # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations ) @@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "generic") or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "sparse", "SparseArray") ): return False @@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return False if ( _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "dask.array", "Array") or _issubclass_fast(cls, "ndonnx", "Array") ): diff --git a/tests/test_jax.py b/tests/test_jax.py index 285958d4..322d0223 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,7 +1,14 @@ from numpy.testing import assert_equal import pytest -from array_api_compat import device, to_device +from array_api_compat import ( + device, + to_device, + is_jax_array, + is_lazy_array, + is_array_api_obj, + is_writeable_array, +) try: import jax @@ -13,7 +20,7 @@ @pytest.mark.parametrize( - "func", + "func", [ lambda x: jnp.zeros(1, device=device(x)), lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), @@ -26,7 +33,7 @@ ), ), lambda x: to_device(jnp.zeros(1), device(x)), - ] + ], ) def test_device_jit(func): # Test work around to https://github.com/jax-ml/jax/issues/26000 @@ -36,3 +43,15 @@ def test_device_jit(func): x = jnp.ones(1) assert_equal(func(x), jnp.asarray([0])) assert_equal(jax.jit(func)(x), jnp.asarray([0])) + + +def test_inside_jit(): + # Test if jax arrays are handled correctly inside jax.jit. + # Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that + # tracers are handled appropriately. For limitations, see is_jax_array() docstring. + # Reference issue: https://github.com/data-apis/array-api-compat/issues/368 + x = jnp.asarray([1, 2, 3]) + assert jax.jit(is_jax_array)(x) + assert jax.jit(is_array_api_obj)(x) + assert not jax.jit(is_writeable_array)(x) + assert jax.jit(is_lazy_array)(x)