-
Notifications
You must be signed in to change notification settings - Fork 17
ENH: apply_where: add kwargs support #624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -211,6 +211,7 @@ def test_device(self, xp: ModuleType, device: Device): | |||||
| ) | ||||||
| @given( | ||||||
| n_arrays=st.integers(min_value=1, max_value=3), | ||||||
| n_kwarrays=st.integers(min_value=1, max_value=3), | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Ideally I'd also like to test the use case where n_arrays=0 and n_kwarrays>0, but it would be a bit more laborious to explain to hypothesis |
||||||
| rng_seed=st.integers(min_value=1000000000, max_value=9999999999), | ||||||
| dtype=npst.floating_dtypes(sizes=(32, 64)), | ||||||
| p=st.floats(min_value=0, max_value=1), | ||||||
|
|
@@ -219,6 +220,7 @@ def test_device(self, xp: ModuleType, device: Device): | |||||
| def test_hypothesis( | ||||||
| self, | ||||||
| n_arrays: int, | ||||||
| n_kwarrays: int, | ||||||
| rng_seed: int, | ||||||
| dtype: np.dtype[Any], | ||||||
| p: float, | ||||||
|
|
@@ -233,9 +235,13 @@ def test_hypothesis( | |||||
| ): | ||||||
| pytest.xfail(reason="NumPy 1.x dtype promotion for scalars") | ||||||
|
|
||||||
| mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0) | ||||||
| mbs = npst.mutually_broadcastable_shapes( | ||||||
| num_shapes=1 + n_arrays + n_kwarrays, min_side=0 | ||||||
| ) | ||||||
| input_shapes, _ = data.draw(mbs) | ||||||
| cond_shape, *shapes = input_shapes | ||||||
| cond_shape = input_shapes[0] | ||||||
| shapes = input_shapes[1 : 1 + n_arrays] | ||||||
| kwshapes = input_shapes[1 + n_arrays :] | ||||||
|
|
||||||
| # cupy/cupy#8382 | ||||||
| # https://github.com/jax-ml/jax/issues/26658 | ||||||
|
|
@@ -257,22 +263,34 @@ def test_hypothesis( | |||||
| for shape in shapes | ||||||
| ) | ||||||
|
|
||||||
| def f1(*args: Array) -> Array: | ||||||
| return cast(Array, sum(args)) | ||||||
| kwargs = { | ||||||
| str(n): xp.asarray( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Variadic function arguments that don't respect the naming convention of python variables work only by accident and could break in future / alternative Python interpreters |
||||||
| data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements)) | ||||||
| ) | ||||||
| for n, shape in enumerate(kwshapes) | ||||||
| } | ||||||
| kwkeys = kwargs.keys() | ||||||
|
|
||||||
| def f1(*args: Array, **kwargs: dict[str, Array]) -> Array: | ||||||
| assert set(kwargs.keys()) == set(kwkeys) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values())) | ||||||
| return cast(Array, sum(args_kwargs)) | ||||||
|
|
||||||
| def f2(*args: Array) -> Array: | ||||||
| return cast(Array, sum(args) / 2) | ||||||
| def f2(*args: Array, **kwargs: dict[str, Array]) -> Array: | ||||||
| assert set(kwargs.keys()) == set(kwkeys) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values())) | ||||||
| return cast(Array, sum(args_kwargs) / 2) | ||||||
|
|
||||||
| rng = np.random.default_rng(rng_seed) | ||||||
| cond = xp.asarray(rng.random(size=cond_shape) > p) | ||||||
|
|
||||||
| res1 = apply_where(cond, arrays, f1, fill_value=fill_value) | ||||||
| res2 = apply_where(cond, arrays, f1, f2) | ||||||
| res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value) | ||||||
| res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs) | ||||||
| res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs) | ||||||
| res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs) | ||||||
|
|
||||||
| ref1 = xp.where(cond, f1(*arrays), fill_value) | ||||||
| ref2 = xp.where(cond, f1(*arrays), f2(*arrays)) | ||||||
| ref3 = xp.where(cond, f1(*arrays), float_fill_value) | ||||||
| ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value) | ||||||
| ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs)) | ||||||
| ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value) | ||||||
|
|
||||||
| xp_assert_close(res1, ref1, rtol=2e-16) | ||||||
| xp_assert_equal(res2, ref2) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be enhanced to allow for non-array kwargs.
The workaround without it, which requires a bit of expertise to figure out, is to hide the non-array kwargs with
functools.partial.This is a strict enhancement suggestion and not blocking for this PR.