Consider
In [1]: import torch

In [2]: from array_api_compat import array_namespace

In [3]: xp = array_namespace(torch.ones(3))

In [4]: m, n = 7, 4.0

In [5]: import array_api_extra as xpx

In [6]: xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)

File ~/.conda/envs/scipy-dev/lib/python3.11/site-packages/array_api_extra/_funcs.py:518, in sinc(x, xp)
    516     raise ValueError(err_msg)
    517 # no scalars in `where` - array-api#807
--> 518 y = xp.pi * xp.where(
    519     x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
    520 )
    521 return xp.sin(y) / y

File ~/repos/array-api-compat/array_api_compat/torch/_aliases.py:503, in where(condition, x1, x2)
    501 def where(condition: array, x1: array, x2: array, /) -> array:
    502     x1, x2 = _fix_promotion(x1, x2)
--> 503     return torch.where(condition, x1, x2)

RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Double
The same MRE on numpy or jax.numpy is OK.
The spec doesn't say whether only bool dtypes for condition are supported (https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.where.html). That should be clarified in the spec, and then:
condition
It sounds like the intention in the standard was to accept only boolean arrays here, but that detail was lost in data-apis/array-api#116.
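If `condition` is indeed meant to be boolean-only, one possible caller-side workaround is to build the mask explicitly with `x != 0` instead of passing the floating-point array itself. The sketch below is only an illustration: `sinc_bool_mask` is a hypothetical helper, not the `array_api_extra` implementation, and it uses NumPy as the namespace so that it runs without torch installed.

```python
import numpy as np


def sinc_bool_mask(x, xp=np):
    """Hypothetical sinc variant that passes a boolean condition to `where`."""
    # Replacement value for x == 0, so that sin(y) / y evaluates to 1 there.
    tiny = xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
    # `x != 0` has a boolean dtype, which is what torch.where insists on.
    y = xp.pi * xp.where(x != 0, x, tiny)
    return xp.sin(y) / y


# Same input as in the report above, built with NumPy.
m, n = 7, 4.0
x = 2.0 * np.arange(n, m, dtype=np.float64) / (m - 1) - 1.0
result = sinc_bool_mask(x)
```

Because the mask already has a boolean dtype, the same expression should not depend on how a backend treats non-boolean conditions, although this was only checked here against NumPy.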