forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_at.py
454 lines (367 loc) · 14.6 KB
/
_at.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""Update operations for read-only arrays."""
from __future__ import annotations
import operator
from collections.abc import Callable
from enum import Enum
from types import ModuleType
from typing import TYPE_CHECKING, ClassVar, cast
from ._utils._compat import (
array_namespace,
is_dask_array,
is_jax_array,
is_writeable_array,
)
from ._utils._helpers import meta_namespace
from ._utils._typing import Array, SetIndex
if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.11)
from typing_extensions import Self
class _AtOp(Enum):
    """
    Names of the update operations exposed by `xpx.at`.

    Each member's value is the lowercase operation name as a string.
    """

    SET = "set"
    ADD = "add"
    SUBTRACT = "subtract"
    MULTIPLY = "multiply"
    DIVIDE = "divide"
    POWER = "power"
    MIN = "min"
    MAX = "max"

    # @override from Python 3.12
    def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]
        """
        Render the member as its bare operation name (handy in pytest logs).

        Returns
        -------
        str
            The member's value, e.g. ``"add"`` for ``_AtOp.ADD``.
        """
        label: str = self.value
        return label
class Undef(Enum):
    """Single-member enum used as a "no value supplied" marker."""

    UNDEF = 0


# Module-wide shorthand for the sole sentinel member; compare against it with
# ``is`` to tell "argument omitted" apart from any real value.
_undef = Undef.UNDEF
class at:  # pylint: disable=invalid-name  # numpydoc ignore=PR02
    """
    Update operations for read-only arrays.

    This implements ``jax.numpy.ndarray.at`` for all writeable
    backends (those that support ``__setitem__``) and routes
    to the ``.at[]`` method for JAX arrays.

    Parameters
    ----------
    x : array
        Input array.
    idx : index, optional
        Only `array API standard compliant indices
        <https://data-apis.org/array-api/latest/API_specification/indexing.html>`_
        are supported.

        You may use two alternate syntaxes::

          >>> import array_api_extra as xpx
          >>> xpx.at(x, idx).set(value)  # or add(value), etc.
          >>> xpx.at(x)[idx].set(value)

    copy : bool, optional
        None (default)
            The array parameter *may* be modified in place if it is
            possible and beneficial for performance.
            You should not reuse it after calling this function.
        True
            Ensure that the inputs are not modified.
        False
            Ensure that the update operation writes back to the input.
            Raise ``ValueError`` if a copy cannot be avoided.

    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Updated input array.

    Warnings
    --------
    (a) When you omit the ``copy`` parameter, you should never reuse the parameter
    array later on; ideally, you should reassign it immediately::

        >>> import array_api_extra as xpx
        >>> x = xpx.at(x, 0).set(2)

    The above best practice pattern ensures that the behaviour won't change depending
    on whether ``x`` is writeable or not, as the original ``x`` object is dereferenced
    as soon as ``xpx.at`` returns; this way there is no risk to accidentally update it
    twice.

    On the reverse, the anti-pattern below must be avoided, as it will result in
    different behaviour on read-only versus writeable arrays::

        >>> x = xp.asarray([0, 0, 0])
        >>> y = xpx.at(x, 0).set(2)
        >>> z = xpx.at(x, 1).set(3)

    In the above example, both calls to ``xpx.at`` update ``x`` in place *if possible*.
    This causes the behaviour to diverge depending on whether ``x`` is writeable or not:

    - If ``x`` is writeable, then after the snippet above you'll have
      ``x == y == z == [2, 3, 0]``
    - If ``x`` is read-only, then you'll end up with
      ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]``.

    The correct pattern to use if you want diverging outputs from the same input is
    to enforce copies::

        >>> x = xp.asarray([0, 0, 0])
        >>> y = xpx.at(x, 0).set(2, copy=True)  # Never updates x
        >>> z = xpx.at(x, 1).set(3)  # May or may not update x in place
        >>> del x  # avoid accidental reuse of x as we don't know its state anymore

    (b) The array API standard does not support integer array indices.
    The behaviour of update methods when the index is an array of integers is
    undefined and will vary between backends; this is particularly true when the
    index contains multiple occurrences of the same index, e.g.::

        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> import array_api_extra as xpx
        >>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
        array([124])
        >>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
        Array([125], dtype=int32)

    See Also
    --------
    jax.numpy.ndarray.at : Equivalent array method in JAX.

    Notes
    -----
    `sparse <https://sparse.pydata.org/>`_, as well as read-only arrays from libraries
    not explicitly covered by ``array-api-compat``, are not supported by update
    methods.

    Boolean masks are supported on Dask and jitted JAX arrays exclusively
    when `idx` has the same shape as `x` and `y` is 0-dimensional.
    Note that this support is not available in JAX's native
    ``x.at[mask].set(y)``.

    This pattern::

        >>> mask = m(x)
        >>> x[mask] = f(x[mask])

    Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::

        >>> mask = m(x)
        >>> x = xpx.at(x, mask).set(f(x[mask]))  # Crash on Dask and jax.jit

    You should instead use::

        >>> x = xp.where(m(x), f(x), x)

    Examples
    --------
    Given either of these equivalent expressions::

      >>> import array_api_extra as xpx
      >>> x = xpx.at(x)[1].add(2)
      >>> x = xpx.at(x, 1).add(2)

    If x is a JAX array, they are the same as::

      >>> x = x.at[1].add(2)

    If x is a read-only NumPy array, they are the same as::

      >>> x = x.copy()
      >>> x[1] += 2

    For other known backends, they are the same as::

      >>> x[1] += 2
    """

    # The array being updated and the index selected so far
    # (``_undef`` until an index has been supplied).
    _x: Array
    _idx: SetIndex | Undef
    __slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")

    def __init__(
        self, x: Array, idx: SetIndex | Undef = _undef, /
    ) -> None:  # numpydoc ignore=GL08
        self._x = x
        self._idx = idx

    def __getitem__(self, idx: SetIndex, /) -> Self:  # numpydoc ignore=PR01,RT01
        """
        Allow for the alternate syntax ``at(x)[start:stop:step]``.

        It looks prettier than ``at(x, slice(start, stop, step))``
        and feels more intuitive coming from the JAX documentation.

        A new instance is returned; ``ValueError`` is raised if an index was
        already provided (either to the constructor or by a previous ``[]``).
        """
        if self._idx is not _undef:
            msg = "Index has already been set"
            raise ValueError(msg)
        return type(self)(self._x, idx)

    def _op(
        self,
        at_op: _AtOp,
        in_place_op: Callable[[Array, Array | complex], Array] | None,
        out_of_place_op: Callable[[Array, Array], Array] | None,
        y: Array | complex,
        /,
        copy: bool | None,
        xp: ModuleType | None,
    ) -> Array:
        """
        Implement all update operations.

        Parameters
        ----------
        at_op : _AtOp
            Method of JAX's Array.at[].
        in_place_op : Callable[[Array, Array | complex], Array] | None
            In-place operation to apply on mutable backends::

                x[idx] = in_place_op(x[idx], y)

            If None::

                x[idx] = y

        out_of_place_op : Callable[[Array, Array], Array] | None
            Out-of-place operation to apply when idx is a boolean mask and the backend
            doesn't support in-place updates::

                x = xp.where(idx, out_of_place_op(x, y), x)

            If None::

                x = xp.where(idx, y, x)

        y : array or complex
            Right-hand side of the operation.
        copy : bool or None
            Whether to copy the input array. See the class docstring for details.
        xp : array_namespace, optional
            The array namespace for the input array. Default: infer.

        Returns
        -------
        Array
            Updated `x`.

        Raises
        ------
        ValueError
            If no index has been set, if `copy` is not one of True, False or
            None, or if an in-place update is required but `x` is read-only.
        """
        from ._funcs import apply_where  # pylint: disable=cyclic-import

        x, idx = self._x, self._idx
        xp = array_namespace(x, y) if xp is None else xp

        if isinstance(idx, Undef):
            msg = (
                "Index has not been set.\n"
                "Usage: either\n"
                "    at(x, idx).set(value)\n"
                "or\n"
                "    at(x)[idx].set(value)\n"
                "(same for all other methods)."
            )
            raise ValueError(msg)

        if copy not in (True, False, None):
            msg = f"copy must be True, False, or None; got {copy!r}"
            raise ValueError(msg)
        if copy is not None:
            # The membership test above compares by equality, so values such
            # as 0, 1 or NumPy booleans get through. Normalise them so that
            # the identity test ``copy is False`` below treats them exactly
            # like the genuine booleans they stand for.
            copy = bool(copy)

        # Only look at writeability when an in-place update is still an option.
        writeable = None if copy else is_writeable_array(x)

        # JAX inside jax.jit doesn't support in-place updates with boolean
        # masks; Dask exclusively supports __setitem__ but not iops.
        # We can handle the common special case of 0-dimensional y
        # with where(idx, y, x) instead.
        if (
            (is_dask_array(idx) or is_jax_array(idx))
            and idx.dtype == xp.bool
            and idx.shape == x.shape
        ):
            y_xp = xp.asarray(y, dtype=x.dtype)
            if y_xp.ndim == 0:
                if out_of_place_op:  # add(), subtract(), ...
                    # suppress inf warnings on Dask
                    out = apply_where(
                        idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
                    )
                    # Undo int->float promotion on JAX after _AtOp.DIVIDE
                    out = xp.astype(out, x.dtype, copy=False)
                else:  # set()
                    out = xp.where(idx, y_xp, x)

                if copy is False:
                    # The caller demanded a write-back: honour the documented
                    # contract and fail with ValueError (rather than whatever
                    # backend-specific error ``__setitem__`` would raise) when
                    # the input cannot be written to.
                    if not writeable:
                        msg = f"Can't update read-only array {x}"
                        raise ValueError(msg)
                    x[()] = out
                    return x
                return out

            # else: this will work on eager JAX and crash on jax.jit and Dask

        if copy or (copy is None and not writeable):
            if is_jax_array(x):
                # Use JAX's at[]
                func = cast(
                    Callable[[Array | complex], Array],
                    getattr(x.at[idx], at_op.value),  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType]
                )
                out = func(y)
                # Undo int->float promotion on JAX after _AtOp.DIVIDE
                return xp.astype(out, x.dtype, copy=False)

            # Emulate at[] behaviour for non-JAX arrays
            # with a copy followed by an update

            x = xp.asarray(x, copy=True)
            # A copy of a read-only numpy array is writeable
            # Note: this assumes that a copy of a writeable array is writeable
            assert not writeable
            # Force a fresh writeability check on the copy below.
            writeable = None

        if writeable is None:
            writeable = is_writeable_array(x)
        if not writeable:
            # sparse crashes here
            msg = f"Can't update read-only array {x}"
            raise ValueError(msg)

        if in_place_op:  # add(), subtract(), ...
            x[idx] = in_place_op(x[idx], y)
        else:  # set()
            x[idx] = y
        return x

    def set(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] = y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp)

    def add(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] += y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        # Note for this and all other methods based on _op:
        # operator.iadd and operator.add subtly differ in behaviour, as
        # only iadd will trigger exceptions when y has an incompatible dtype.
        return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp)

    def subtract(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] -= y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        return self._op(
            _AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp
        )

    def multiply(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] *= y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        return self._op(
            _AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp
        )

    def divide(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] /= y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        return self._op(
            _AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp
        )

    def power(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] **= y`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp)

    def min(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        # On Dask, this function runs on the chunks, so we need to determine the
        # namespace that Dask is wrapping.
        # Note that da.minimum _incidentally_ works on NumPy, CuPy, and sparse
        # thanks to all these meta-namespaces implementing the __array_ufunc__
        # interface, but there's no guarantee that it will work for other
        # wrapped libraries in the future.
        xp = array_namespace(self._x) if xp is None else xp
        mxp = meta_namespace(self._x, xp=xp)
        # minimum() needs an array operand, so convert y before dispatching.
        y = xp.asarray(y)
        return self._op(_AtOp.MIN, mxp.minimum, mxp.minimum, y, copy=copy, xp=xp)

    def max(
        self,
        y: Array | complex,
        /,
        copy: bool | None = None,
        xp: ModuleType | None = None,
    ) -> Array:  # numpydoc ignore=PR01,RT01
        """
        Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array.

        ``copy`` and ``xp`` have the meaning described in the class docstring.
        """
        # See note on min()
        xp = array_namespace(self._x) if xp is None else xp
        mxp = meta_namespace(self._x, xp=xp)
        # maximum() needs an array operand, so convert y before dispatching.
        y = xp.asarray(y)
        return self._op(_AtOp.MAX, mxp.maximum, mxp.maximum, y, copy=copy, xp=xp)