-
Notifications
You must be signed in to change notification settings - Fork 39
Fix is_jax_array for jax>=0.8.2
#369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: | |
| is_pydata_sparse_array | ||
| """ | ||
| cls = cast(Hashable, type(x)) | ||
| return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) | ||
| # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on, | ||
| # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent | ||
| # non-array values and a fully correct implementation would need to use isinstance checks. Since | ||
| # we use hash-based caching with type names as keys, we cannot use instance checks without | ||
| # losing performance here. For more information, see | ||
| # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue. | ||
| return ( | ||
| _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") | ||
| or _is_jax_zero_gradient_array(x) | ||
| ) | ||
|
Comment on lines
+244
to
+248
|
||
|
|
||
|
|
||
| def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | ||
|
|
@@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool: | |
| or _issubclass_fast(cls, "sparse", "SparseArray") | ||
| # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ | ||
| or _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None: | |
| if ( | ||
| _issubclass_fast(cls, "numpy", "generic") | ||
| or _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations | ||
| or _issubclass_fast(cls, "sparse", "SparseArray") | ||
| ): | ||
| return False | ||
|
|
@@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None: | |
| return False | ||
| if ( | ||
| _issubclass_fast(cls, "jax", "Array") | ||
| or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations | ||
| or _issubclass_fast(cls, "dask.array", "Array") | ||
| or _issubclass_fast(cls, "ndonnx", "Array") | ||
| ): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main reason for the change in v0.8.2 is that tracers now can represent more than just arrays, and so returning True for any tracer may lead to false positives.
The logic in
Array.__instancecheck__is what is required to accurately check in all contexts whetherxis an array: https://github.com/jax-ml/jax/blob/82ae1b1cde42a5b93e00d8c3376cde627c2d83bb/jaxlib/py_array.cc#L2187-L2218The easiest way to accomplish this would be to check
isinstance(x, jax.Array)rather than recreating that logic here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That will force us to use a non-cachable operation, which is going to slow things down. But I don't think we have a choice given that the
Tracertype itself no longer holds information on whether or not it's anArray.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jakevdp Can you elaborate a bit more on which kinds of non-array objects now create tracers? I.e. we use an
_is_writable_clsand_is_lazy_cls. Even if tracers are not arrays, these functions could still be decidable based on the type only. Are tracers still always lazy and always immutable? I realize that these questions might be ill-defined since tracers do not represent real objects and can disappear from the final computation graph, but for our purposes that's not an issue.Also, could you show an example of a tracer that does not wrap an array? E.g. are bools in the input now traced as bools and not as arrays? This would be very helpful for testing.
@crusaderky Current helper methods such as
_is_writable_clsare designed to returnNonefor non-array API objects. It seems we cannot make that decision based off of type information only onjax>=0.8.2. Are you fine with relaxing theNonestrategy and returningTrueforTracers in general, or do you want to be strict here? The former still fits into our current setup, the latter must use non-cachableisinstancechecks.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An example is the new hijax Box type. There are no public APIs for this (yet), but here's how you can construct it using currently-private APIs at head:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current design is that
Tracersubclass reflects the type of transformation being traced (e.g. jit, vmap, grad, jaxpr, etc.) while theavalattribute can be inspected to see what kind of object is being traced.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that's very helpful. At this point I think we need a decision by the array-api-compat team. Both versions shouldn't be hard to implement.
@crusaderky @lucascolley what are your thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the false positives seem like a concern from SciPy's side. Maybe we go with this, but add a note into the code comments about the false positives in case anyone complains in the future?