|
10 | 10 |
|
11 | 11 | from . import _scalar_types
|
12 | 12 |
|
| 13 | + |
| 14 | +__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype'] |
| 15 | + |
| 16 | + |
13 | 17 | # Define analogs of numpy dtypes supported by pytorch.
|
14 | 18 |
|
15 | 19 | class dtype:
|
@@ -177,6 +181,23 @@ def is_integer(dtyp):
|
177 | 181 | return dtyp.typecode in typecodes['AllInteger']
|
178 | 182 |
|
179 | 183 |
|
| 184 | + |
| 185 | +def issubclass_(arg, klass): |
| 186 | + try: |
| 187 | + return issubclass(arg, klass) |
| 188 | + except TypeError: |
| 189 | + return False |
| 190 | + |
| 191 | + |
| 192 | +def issubdtype(arg1, arg2): |
| 193 | + # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420 |
| 194 | + if not issubclass_(arg1, _scalar_types.generic): |
| 195 | + arg1 = dtype(arg1).type |
| 196 | + if not issubclass_(arg2, _scalar_types.generic): |
| 197 | + arg2 = dtype(arg2).type |
| 198 | + return issubclass(arg1, arg2) |
| 199 | + |
| 200 | + |
180 | 201 | # The casting below is defined *with dtypes only*, so no value-based casting!
|
181 | 202 |
|
182 | 203 | # These two dicts are autogenerated with autogen/gen_dtypes.py,
|
@@ -210,6 +231,3 @@ def is_integer(dtyp):
|
210 | 231 |
|
211 | 232 | ########################## end autogenerated part
|
212 | 233 |
|
213 |
| - |
214 |
| -__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] |
215 |
| - |
|
0 commit comments