Skip to content

Commit 8eec0c1

Browse files
committed
ENH: add np.issubdtype checker to mimic numpy
1 parent 9a92ec3 commit 8eec0c1

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

Diff for: torch_np/_dtypes.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
from . import _scalar_types
1212

13+
14+
__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype']
15+
16+
1317
# Define analogs of numpy dtypes supported by pytorch.
1418

1519
class dtype:
@@ -177,6 +181,23 @@ def is_integer(dtyp):
177181
return dtyp.typecode in typecodes['AllInteger']
178182

179183

184+
185+
def issubclass_(arg, klass):
186+
try:
187+
return issubclass(arg, klass)
188+
except TypeError:
189+
return False
190+
191+
192+
def issubdtype(arg1, arg2):
193+
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
194+
if not issubclass_(arg1, _scalar_types.generic):
195+
arg1 = dtype(arg1).type
196+
if not issubclass_(arg2, _scalar_types.generic):
197+
arg2 = dtype(arg2).type
198+
return issubclass(arg1, arg2)
199+
200+
180201
# The casting below is defined *with dtypes only*, so no value-based casting!
181202

182203
# These two dicts are autogenerated with autogen/gen_dtypes.py,
@@ -210,6 +231,3 @@ def is_integer(dtyp):
210231

211232
########################## end autogenerated part
212233

213-
214-
__all__ = ['dtype_from_torch', 'dtype', 'typecodes']
215-

0 commit comments

Comments
 (0)