Error when broadcasting array API compliant class #8665

Closed
5 tasks done
TomNicholas opened this issue Jan 25, 2024 · 1 comment · Fixed by #8669
Labels
array API standard Support for the Python array API standard bug topic-arrays related to flexible array support

What happened?

Broadcasting fails for array types that strictly follow the array API standard.

What did you expect to happen?

The multiplication should broadcast and succeed, as it does when the Variable wraps a normal numpy array.

Minimal Complete Verifiable Example

import numpy as np
import numpy.array_api as nxp
import xarray as xr

arr = nxp.asarray([[1, 2, 3], [4, 5, 6]], dtype=np.dtype('float32'))

var = xr.Variable(data=arr, dims=['x', 'y'])

var.isel(x=0)  # this is fine

var * var.isel(x=0)  # this is not

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[31], line 1
----> 1 var * var.isel(x=0)

File ~/Documents/Work/Code/xarray/xarray/core/_typed_ops.py:487, in VariableOpsMixin.__mul__(self, other)
    486 def __mul__(self, other: VarCompatible) -> Self | T_DataArray:
--> 487     return self._binary_op(other, operator.mul)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2406, in Variable._binary_op(self, other, f, reflexive)
   2404     other_data, self_data, dims = _broadcast_compat_data(other, self)
   2405 else:
-> 2406     self_data, other_data, dims = _broadcast_compat_data(self, other)
   2407 keep_attrs = _get_keep_attrs(default=False)
   2408 attrs = self._attrs if keep_attrs else None

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2922, in _broadcast_compat_data(self, other)
   2919 def _broadcast_compat_data(self, other):
   2920     if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
   2921         # `other` satisfies the necessary Variable API for broadcast_variables
-> 2922         new_self, new_other = _broadcast_compat_variables(self, other)
   2923         self_data = new_self.data
   2924         other_data = new_other.data

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in _broadcast_compat_variables(*variables)
   2893 """Create broadcast compatible variables, with the same dimensions.
   2894 
   2895 Unlike the result of broadcast_variables(), some variables may have
   2896 dimensions of size 1 instead of the size of the broadcast dimension.
   2897 """
   2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:2899, in <genexpr>(.0)
   2893 """Create broadcast compatible variables, with the same dimensions.
   2894 
   2895 Unlike the result of broadcast_variables(), some variables may have
   2896 dimensions of size 1 instead of the size of the broadcast dimension.
   2897 """
   2898 dims = tuple(_unified_dims(variables))
-> 2899 return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1479, in Variable.set_dims(self, dims, shape)
   1477     expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
   1478 else:
-> 1479     expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)]
   1481 expanded_var = Variable(
   1482     expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
   1483 )
   1484 return expanded_var.transpose(*dims)

File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:555, in Array.__getitem__(self, key)
    550 """
    551 Performs the operation __getitem__.
    552 """
    553 # Note: Only indices required by the spec are allowed. See the
    554 # docstring of _validate_index
--> 555 self._validate_index(key)
    556 if isinstance(key, Array):
    557     # Indexing self._array with array_api arrays can be erroneous
    558     key = key._array

File ~/miniconda3/envs/dev3.11/lib/python3.12/site-packages/numpy/array_api/_array_object.py:348, in Array._validate_index(self, key)
    344 elif n_ellipsis == 0:
    345     # Note boolean masks must be the sole index, which we check for
    346     # later on.
    347     if not key_has_mask and n_single_axes < self.ndim:
--> 348         raise IndexError(
    349             f"{self.ndim=}, but the multi-axes index only specifies "
    350             f"{n_single_axes} dimensions. If this was intentional, "
    351             "add a trailing ellipsis (...) which expands into as many "
    352             "slices (:) as necessary - this is what np.ndarray arrays "
    353             "implicitly do, but such flat indexing behaviour is not "
    354             "specified in the Array API."
    355         )
    357 if n_ellipsis == 0:
    358     indexed_shape = self.shape

IndexError: self.ndim=1, but the multi-axes index only specifies 0 dimensions. If this was intentional, add a trailing ellipsis (...) which expands into as many slices (:) as necessary - this is what np.ndarray arrays implicitly do, but such flat indexing behaviour is not specified in the Array API.
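
For anyone skimming the traceback: the failing key is `(None,) * (len(expanded_dims) - self.ndim)`, which in this example works out to `(None,)`. The strict indexing check counts only entries that address an existing axis, so the 1-D array sees 0 indexed dimensions and raises. Below is a minimal pure-Python sketch of that rule. `count_single_axes` and `check_strict_index` are hypothetical helpers written for illustration, a simplified stand-in rather than numpy's actual `_validate_index`.

```python
def count_single_axes(key):
    """Count entries that index an existing axis (ints and slices).

    None (new axis) and Ellipsis do not count.
    """
    return sum(k is not None and k is not Ellipsis for k in key)


def check_strict_index(key, ndim):
    """Simplified stand-in for the strict check that raised above.

    Hypothetical helper; not numpy's actual implementation.
    """
    has_ellipsis = any(k is Ellipsis for k in key)
    n_single_axes = count_single_axes(key)
    if not has_ellipsis and n_single_axes < ndim:
        raise IndexError(
            f"ndim={ndim}, but the index only specifies {n_single_axes} dimensions"
        )


ndim = 1    # var.isel(x=0) is 1-D
n_new = 1   # one dimension ("x") has to be added in front

key = (None,) * n_new  # the key built in the set_dims frame of the traceback
try:
    check_strict_index(key, ndim)
except IndexError as err:
    print("rejected:", err)

strict_key = key + (...,)  # the trailing ellipsis the error message suggests
check_strict_index(strict_key, ndim)  # passes silently
```

Appending `...` to the key is what the error message itself suggests. With a plain `np.ndarray` the two keys should behave the same, since numpy fills in the missing slices implicitly.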

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

No response

Environment

main branch of xarray, numpy 1.26.0

TomNicholas commented Jan 25, 2024

It's kind of weird that given these DataArrays wrapping array API arrays

var1 = xr.DataArray(
    nxp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
    dims=("x", "y"),
    coords={"x": [10, 20]},
)
var2 = xr.DataArray(nxp.asarray([1.0, 2.0]), dims="x")

this doesn't work

var1 * var2
IndexError: self.ndim=1, but the multi-axes index only specifies 0 dimensions. If this was intentional, add a trailing ellipsis (...) which expands into as many slices (:) as necessary - this is what np.ndarray arrays implicitly do, but such flat indexing behaviour is not specified in the Array API.

but this does

var1_expanded, var2_expanded = xr.broadcast(var1, var2)
var1_expanded * var2_expanded

Naively I would have thought that the broadcasting should go through the same code path in either case?
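
A possible explanation, going only by the traceback: `Variable.set_dims` has one branch that calls `duck_array_ops.broadcast_to` when a shape is given, and another that indexes with `None` when it is not. The `_broadcast_compat_variables` docstring also says its result may keep size-1 dimensions, unlike `broadcast_variables()`. The sketch below shows the difference in resulting shapes. `expanded_shape` is a hypothetical helper for illustration, and the idea that `xr.broadcast` ends up in the shape-aware branch is an assumption, not something the traceback confirms.

```python
def expanded_shape(var_dims, var_shape, target_dims, sizes=None):
    """Shape a variable would have after being expanded to target_dims.

    Hypothetical helper, for illustration only.
    sizes=None -> "compat" path: missing dims are inserted with length 1
    sizes=dict -> "full" path: every dim takes its broadcast size
    """
    own = dict(zip(var_dims, var_shape))
    if sizes is None:
        return tuple(own.get(d, 1) for d in target_dims)
    return tuple(sizes[d] for d in target_dims)


dims = ("x", "y")
sizes = {"x": 2, "y": 3}

# var2 from the comment above has dims ("x",) and shape (2,)
print(expanded_shape(("x",), (2,), dims))         # (2, 1)
print(expanded_shape(("x",), (2,), dims, sizes))  # (2, 3)
```

Both shapes multiply correctly against a `(2, 3)` array, but they would be produced by different operations on the underlying data, which could be why only one of them trips the strict indexing check.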
