Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
f2: Callable[..., Array],
/,
*,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array: ...

Expand All @@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
/,
*,
fill_value: Array | complex,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array: ...

Expand All @@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
/,
*,
fill_value: Array | complex | None = None,
kwargs: dict[str, Array] | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Expand All @@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
It does not need to be scalar; it needs however to be broadcastable with
`cond` and `args`.
Mutually exclusive with `f2`. You must provide one or the other.
kwargs : dict of str : Array pairs
Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
`cond`.
xp : array_namespace, optional
The standard-compatible namespace for `cond` and `args`. Default: infer.

Expand Down Expand Up @@ -129,6 +135,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
args_ = list(args) if isinstance(args, tuple) else [args]
del args

kwargs_ = {} if kwargs is None else kwargs
kwkeys = list(kwargs_.keys())
args_ = [*args_, *kwargs_.values()]
del kwargs

xp = array_namespace(cond, fill_value, *args_) if xp is None else xp

if isinstance(fill_value, int | float | complex | NoneType):
Expand All @@ -139,8 +150,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
if is_dask_namespace(xp):
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
# map_blocks doesn't descend into tuples of Arrays
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
return xp.map_blocks(
_apply_where, cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=meta_xp
)

return _apply_where(cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=xp)


def _apply_where( # numpydoc ignore=PR01,RT01
Expand All @@ -149,15 +163,26 @@ def _apply_where( # numpydoc ignore=PR01,RT01
f2: Callable[..., Array] | None,
fill_value: Array | int | float | complex | bool | None,
*args: Array,
kwkeys: list[str],
xp: ModuleType,
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

nargs = len(args) - len(kwkeys)
kwargs = dict(zip(kwkeys, args[nargs:], strict=True))
args = args[:nargs]

if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
return xp.where(
cond,
f1(*args, **kwargs),
f2(*args, **kwargs) if f2 is not None else fill_value,
)

temp1 = f1(*(arr[cond] for arr in args))
temp1 = f1(
*(arr[cond] for arr in args), **{key: val[cond] for key, val in kwargs.items()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be enhanced to allow non-array kwargs.
Without that, the workaround — which takes a bit of expertise to figure out — is to hide the non-array kwargs by binding them beforehand with `functools.partial`.

This is strictly an enhancement suggestion and does not block this PR.

)

if f2 is None:
dtype = xp.result_type(temp1, fill_value)
Expand All @@ -167,7 +192,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
out = xp.astype(fill_value, dtype, copy=True)
else:
ncond = ~cond
temp2 = f2(*(arr[ncond] for arr in args))
temp2 = f2(
*(arr[ncond] for arr in args),
**{key: val[ncond] for key, val in kwargs.items()},
)
dtype = xp.result_type(temp1, temp2)
out = xp.empty_like(cond, dtype=dtype)
out = at(out, ncond).set(temp2)
Expand Down
42 changes: 30 additions & 12 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def test_device(self, xp: ModuleType, device: Device):
)
@given(
n_arrays=st.integers(min_value=1, max_value=3),
n_kwarrays=st.integers(min_value=1, max_value=3),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
n_kwarrays=st.integers(min_value=1, max_value=3),
n_kwarrays=st.integers(min_value=0, max_value=3),

Ideally I'd also like to test the use case where `n_arrays=0` and `n_kwarrays>0`, but that would be a bit more laborious to express as a Hypothesis strategy.

rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
dtype=npst.floating_dtypes(sizes=(32, 64)),
p=st.floats(min_value=0, max_value=1),
Expand All @@ -219,6 +220,7 @@ def test_device(self, xp: ModuleType, device: Device):
def test_hypothesis(
self,
n_arrays: int,
n_kwarrays: int,
rng_seed: int,
dtype: np.dtype[Any],
p: float,
Expand All @@ -233,9 +235,13 @@ def test_hypothesis(
):
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")

mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0)
mbs = npst.mutually_broadcastable_shapes(
num_shapes=1 + n_arrays + n_kwarrays, min_side=0
)
input_shapes, _ = data.draw(mbs)
cond_shape, *shapes = input_shapes
cond_shape = input_shapes[0]
shapes = input_shapes[1 : 1 + n_arrays]
kwshapes = input_shapes[1 + n_arrays :]

# cupy/cupy#8382
# https://github.com/jax-ml/jax/issues/26658
Expand All @@ -257,22 +263,34 @@ def test_hypothesis(
for shape in shapes
)

def f1(*args: Array) -> Array:
return cast(Array, sum(args))
kwargs = {
str(n): xp.asarray(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
str(n): xp.asarray(
f"kw{n}": xp.asarray(

Keyword argument names that are not valid Python identifiers (such as `"0"`) work with `**kwargs` only by accident, and could break in future or alternative Python interpreters.

data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
)
for n, shape in enumerate(kwshapes)
}
kwkeys = kwargs.keys()

def f1(*args: Array, **kwargs: dict[str, Array]) -> Array:
assert set(kwargs.keys()) == set(kwkeys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert set(kwargs.keys()) == set(kwkeys)
assert kwargs.keys() == kwkeys

args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
return cast(Array, sum(args_kwargs))

def f2(*args: Array) -> Array:
return cast(Array, sum(args) / 2)
def f2(*args: Array, **kwargs: dict[str, Array]) -> Array:
assert set(kwargs.keys()) == set(kwkeys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert set(kwargs.keys()) == set(kwkeys)
assert kwargs.keys() == kwkeys

args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
return cast(Array, sum(args_kwargs) / 2)

rng = np.random.default_rng(rng_seed)
cond = xp.asarray(rng.random(size=cond_shape) > p)

res1 = apply_where(cond, arrays, f1, fill_value=fill_value)
res2 = apply_where(cond, arrays, f1, f2)
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value)
res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs)
res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs)
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs)

ref1 = xp.where(cond, f1(*arrays), fill_value)
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value)
ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs))
ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value)

xp_assert_close(res1, ref1, rtol=2e-16)
xp_assert_equal(res2, ref2)
Expand Down