Skip to content

MAINT: Simplify at implementation #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 51 additions & 60 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
raise ValueError(msg)
return at(self._x, idx)

def _update_common(
def _op(
self,
at_op: _AtOp,
y: Array,
in_place_op: Callable[[Array, Array | object], Array] | None,
y: Array | object,
/,
copy: bool | None,
xp: ModuleType | None,
) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01
) -> Array:
"""
Perform common prepocessing to all update operations.
Implement all update operations.

Parameters
----------
at_op : _AtOp
Method of JAX's Array.at[].
in_place_op : Callable[[Array, Array | object], Array] | None
In-place operation to apply on mutable backends::

x[idx] = in_place_op(x[idx], y)

If None::

x[idx] = y

y : array or object
Right-hand side of the operation.
copy : bool or None
Whether to copy the input array. See the class docstring for details.
xp : array_namespace or None
The array namespace for the input array.

Returns
-------
tuple
If the operation can be resolved by ``at[]``, ``(return value, None)``
Otherwise, ``(None, preprocessed x)``.
Array
Updated `x`.
"""
x, idx = self._x, self._idx

Expand Down Expand Up @@ -231,7 +251,7 @@ def _update_common(
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
return func(y), None
return func(y)
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
Expand All @@ -249,52 +269,25 @@ def _update_common(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

return None, x
if in_place_op:
x[self._idx] = in_place_op(x[self._idx], y)
else: # set()
x[self._idx] = y
return x

def set(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = y`` and return the update array."""
res, x = self._update_common(_AtOp.SET, y, copy=copy, xp=xp)
if res is not None:
return res
assert x is not None
x[self._idx] = y
return x

def _iop(
self,
at_op: _AtOp,
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
copy: bool | None,
xp: ModuleType | None,
) -> Array: # numpydoc ignore=PR01,RT01
"""
``x[idx] += y`` or equivalent in-place operation on a subset of x.

which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
if res is not None:
return res
assert x is not None
x[self._idx] = elwise_op(x[self._idx], y)
return x
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)

def add(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -304,70 +297,68 @@ def add(
# Note for this and all other methods based on _iop:
# operator.iadd and operator.add subtly differ in behaviour, as
# only iadd will trigger exceptions when y has an incompatible dtype.
return self._iop(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)

def subtract(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] -= y`` and return the updated array."""
return self._iop(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)

def multiply(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] *= y`` and return the updated array."""
return self._iop(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)

def divide(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] /= y`` and return the updated array."""
return self._iop(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)

def power(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] **= y`` and return the updated array."""
return self._iop(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)

def min(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
if xp is None:
xp = array_namespace(self._x)
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._iop(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)

def max(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
if xp is None:
xp = array_namespace(self._x)
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._iop(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
Loading