Skip to content

Commit a33ad3e

Browse files
authored
Merge pull request #15 from lucascolley/scipy
some fixes
2 parents 4fee5fd + 9af6561 commit a33ad3e

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

Diff for: src/pint_array/__init__.py

+36-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import types
1515
from typing import Generic
1616

17+
from array_api_compat import size
1718
from pint import Quantity
1819
from pint.facets.plain import MagnitudeT, PlainQuantity
1920

@@ -107,6 +108,15 @@ def __repr__(self):
107108
f" '{self.units}'\n)>"
108109
)
109110

111+
def __mul__(self, other):
112+
if hasattr(other, "units"):
113+
magnitude = self._call_super_method("__mul__", other.magnitude)
114+
units = self.units * other.units
115+
else:
116+
magnitude = self._call_super_method("__mul__", other)
117+
units = self.units
118+
return ArrayUnitQuantity(magnitude, units)
119+
110120
## Linear Algebra Methods ##
111121
def __matmul__(self, other):
112122
return mod.matmul(self, other)
@@ -133,11 +143,11 @@ def mT(self):
133143
def __dlpack_device__(self):
134144
return self.magnitude.__dlpack_device__()
135145

136-
def __dlpack__(self, **kwargs):
146+
def __dlpack__(self, stream=None, max_version=None, dl_device=None, copy=None):
137147
# really not sure how to define this
138-
return self.magnitude.__dlpack__(**kwargs)
139-
140-
__dlpack__.__signature__ = inspect.signature(xp.empty(0).__dlpack__)
148+
return self.magnitude.__dlpack__(
149+
stream=stream, max_version=max_version, dl_device=dl_device, copy=copy
150+
)
141151

142152
def to_device(self, device, /, *, stream=None):
143153
_magnitude = self._magnitude.to_device(device, stream=stream)
@@ -185,7 +195,7 @@ def fun(self, name=name):
185195
"__lshift__",
186196
"__lt__",
187197
"__mod__",
188-
"__mul__",
198+
# "__mul__",
189199
"__ne__",
190200
"__or__",
191201
"__pow__",
@@ -301,7 +311,8 @@ def manip_fun(x, *args, **kwargs):
301311
magnitude = xp.asarray(x.magnitude, copy=True)
302312
units = x.units
303313
elif hasattr(x, "__array_namespace__"):
304-
magnitude = x
314+
x = asarray(x)
315+
magnitude = xp.asarray(x.magnitude, copy=True)
305316
units = None
306317
one_array = True
307318
else:
@@ -390,7 +401,9 @@ def astype(x, dtype, /, *, copy=True, device=None):
390401
if device is None and not copy and dtype == x.dtype:
391402
return x
392403
x = asarray(x)
393-
magnitude = xp.astype(x.magnitude, dtype, copy=copy, device=device)
404+
# https://github.com/data-apis/array-api-compat/issues/226
405+
# magnitude = xp.astype(x.magnitude, dtype, copy=copy, device=device)
406+
magnitude = xp.astype(x.magnitude, dtype, copy=copy)
394407
return ArrayUnitQuantity(magnitude, x.units)
395408

396409
mod.astype = astype
@@ -600,7 +613,7 @@ def where(condition, x1, x2, /):
600613
def fun(x, /, *args, func_str=func_str, **kwargs):
601614
x = asarray(x)
602615
magnitude = xp.asarray(x.magnitude, copy=True)
603-
magnitude = getattr(xp, func_str)(x, *args, **kwargs)
616+
magnitude = getattr(xp, func_str)(magnitude, *args, **kwargs)
604617
return ArrayUnitQuantity(magnitude, x.units)
605618

606619
setattr(mod, func_str, fun)
@@ -651,6 +664,20 @@ def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
651664

652665
setattr(mod, func_str, fun)
653666

667+
def multiply(x1, x2, /, *args, **kwargs):
668+
x1 = asarray(x1)
669+
x2 = asarray(x2)
670+
671+
units = x1.units * x2.units
672+
673+
x1_magnitude = xp.asarray(x1.magnitude, copy=True)
674+
x2_magnitude = x2.m_as(x1.units)
675+
676+
magnitude = xp.multiply(x1_magnitude, x2_magnitude, *args, **kwargs)
677+
return ArrayUnitQuantity(magnitude, units)
678+
679+
mod.multiply = multiply
680+
654681
## Indexing Functions
655682
def take(x, indices, /, **kwargs):
656683
magnitude = xp.take(x.magnitude, indices.magnitude, **kwargs)
@@ -791,7 +818,7 @@ def var(x, /, *args, **kwargs):
791818
def prod(x, /, *args, axis=None, **kwargs):
792819
x = asarray(x)
793820
magnitude = xp.asarray(x.magnitude, copy=True)
794-
exponent = magnitude.shape[axis] if axis is not None else magnitude.size
821+
exponent = magnitude.shape[axis] if axis is not None else size(magnitude)
795822
units = x.units**exponent
796823
magnitude = xp.prod(magnitude, *args, axis=axis, **kwargs)
797824
return ArrayUnitQuantity(magnitude, units)

0 commit comments

Comments
 (0)