14
14
import types
15
15
from typing import Generic
16
16
17
+ from array_api_compat import size
17
18
from pint import Quantity
18
19
from pint .facets .plain import MagnitudeT , PlainQuantity
19
20
@@ -107,6 +108,15 @@ def __repr__(self):
107
108
f" '{ self .units } '\n )>"
108
109
)
109
110
111
+ def __mul__ (self , other ):
112
+ if hasattr (other , "units" ):
113
+ magnitude = self ._call_super_method ("__mul__" , other .magnitude )
114
+ units = self .units * other .units
115
+ else :
116
+ magnitude = self ._call_super_method ("__mul__" , other )
117
+ units = self .units
118
+ return ArrayUnitQuantity (magnitude , units )
119
+
110
120
## Linear Algebra Methods ##
111
121
def __matmul__ (self , other ):
112
122
return mod .matmul (self , other )
@@ -133,11 +143,11 @@ def mT(self):
133
143
def __dlpack_device__ (self ):
134
144
return self .magnitude .__dlpack_device__ ()
135
145
136
- def __dlpack__ (self , ** kwargs ):
146
+ def __dlpack__ (self , stream = None , max_version = None , dl_device = None , copy = None ):
137
147
# really not sure how to define this
138
- return self .magnitude .__dlpack__ (** kwargs )
139
-
140
- __dlpack__ . __signature__ = inspect . signature ( xp . empty ( 0 ). __dlpack__ )
148
+ return self .magnitude .__dlpack__ (
149
+ stream = stream , max_version = max_version , dl_device = dl_device , copy = copy
150
+ )
141
151
142
152
def to_device (self , device , / , * , stream = None ):
143
153
_magnitude = self ._magnitude .to_device (device , stream = stream )
@@ -185,7 +195,7 @@ def fun(self, name=name):
185
195
"__lshift__" ,
186
196
"__lt__" ,
187
197
"__mod__" ,
188
- "__mul__" ,
198
+ # "__mul__",
189
199
"__ne__" ,
190
200
"__or__" ,
191
201
"__pow__" ,
@@ -301,7 +311,8 @@ def manip_fun(x, *args, **kwargs):
301
311
magnitude = xp .asarray (x .magnitude , copy = True )
302
312
units = x .units
303
313
elif hasattr (x , "__array_namespace__" ):
304
- magnitude = x
314
+ x = asarray (x )
315
+ magnitude = xp .asarray (x .magnitude , copy = True )
305
316
units = None
306
317
one_array = True
307
318
else :
@@ -390,7 +401,9 @@ def astype(x, dtype, /, *, copy=True, device=None):
390
401
if device is None and not copy and dtype == x .dtype :
391
402
return x
392
403
x = asarray (x )
393
- magnitude = xp .astype (x .magnitude , dtype , copy = copy , device = device )
404
+ # https://github.com/data-apis/array-api-compat/issues/226
405
+ # magnitude = xp.astype(x.magnitude, dtype, copy=copy, device=device)
406
+ magnitude = xp .astype (x .magnitude , dtype , copy = copy )
394
407
return ArrayUnitQuantity (magnitude , x .units )
395
408
396
409
mod .astype = astype
@@ -600,7 +613,7 @@ def where(condition, x1, x2, /):
600
613
def fun (x , / , * args , func_str = func_str , ** kwargs ):
601
614
x = asarray (x )
602
615
magnitude = xp .asarray (x .magnitude , copy = True )
603
- magnitude = getattr (xp , func_str )(x , * args , ** kwargs )
616
+ magnitude = getattr (xp , func_str )(magnitude , * args , ** kwargs )
604
617
return ArrayUnitQuantity (magnitude , x .units )
605
618
606
619
setattr (mod , func_str , fun )
@@ -651,6 +664,20 @@ def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
651
664
652
665
setattr (mod , func_str , fun )
653
666
667
+ def multiply (x1 , x2 , / , * args , ** kwargs ):
668
+ x1 = asarray (x1 )
669
+ x2 = asarray (x2 )
670
+
671
+ units = x1 .units * x2 .units
672
+
673
+ x1_magnitude = xp .asarray (x1 .magnitude , copy = True )
674
+ x2_magnitude = x2 .m_as (x1 .units )
675
+
676
+ magnitude = xp .multiply (x1_magnitude , x2_magnitude , * args , ** kwargs )
677
+ return ArrayUnitQuantity (magnitude , units )
678
+
679
+ mod .multiply = multiply
680
+
654
681
## Indexing Functions
655
682
def take (x , indices , / , ** kwargs ):
656
683
magnitude = xp .take (x .magnitude , indices .magnitude , ** kwargs )
@@ -791,7 +818,7 @@ def var(x, /, *args, **kwargs):
791
818
def prod (x , / , * args , axis = None , ** kwargs ):
792
819
x = asarray (x )
793
820
magnitude = xp .asarray (x .magnitude , copy = True )
794
- exponent = magnitude .shape [axis ] if axis is not None else magnitude . size
821
+ exponent = magnitude .shape [axis ] if axis is not None else size ( magnitude )
795
822
units = x .units ** exponent
796
823
magnitude = xp .prod (magnitude , * args , axis = axis , ** kwargs )
797
824
return ArrayUnitQuantity (magnitude , units )
0 commit comments