Smoke valid args for binary ufunc tests #18

Closed
wants to merge 22 commits
Commits
08cc8ed
ENH: introduce scalar type hierarchy
ev-br Jan 3, 2023
fd4a4d9
TST: undo (some) test skips of tests with scalars
ev-br Jan 3, 2023
5b4d716
ENH: add np.issubdtype checker to mimic numpy
ev-br Jan 3, 2023
c5f4949
ENH: introduce scalar type hierarchy
ev-br Jan 3, 2023
e1fa959
TST: undo (some) test skips of tests with scalars
ev-br Jan 3, 2023
5c2e6f9
ENH: add np.issubdtype checker to mimic numpy
ev-br Jan 3, 2023
7622143
MAINT: adapt assert_equal, assert_array_equal
ev-br Jan 4, 2023
0a391da
TST: fix test_scalar_ctors from numpy
ev-br Jan 4, 2023
1c8900e
MAINT: test_scalar_methods from numpy
ev-br Jan 5, 2023
43a894d
MAINT: numpy-vendored tests get through the collection stage
ev-br Jan 5, 2023
5c9adde
MAINT: multiple assorted fixes to make numpy tests pass
ev-br Jan 5, 2023
c0d5113
BUG: np.asarray(arr) returns arr not a copy
ev-br Jan 5, 2023
09ce7e0
BUG: fix import in test_ufunc_basic
ev-br Jan 5, 2023
0b5e9a7
API: add tests to stipulate equivalence of array scalars and 0D arrays
ev-br Jan 5, 2023
5be93f1
TST: test_numerictypes: remove definitely unsupported things
ev-br Jan 5, 2023
a07fab6
BUG: fix the scalar type hierarchy, so that issubdtype works.
ev-br Jan 5, 2023
eec3bba
ENH: add dtype.itemsize, rm a bunch of tests of timedelta, dtype(str) …
ev-br Jan 5, 2023
adf9c73
ENH: dtypes pickle/unpickle
ev-br Jan 6, 2023
2d7d932
TST: test_dtype from NumPy passes (with skips/fails, of course)
ev-br Jan 6, 2023
6993215
ENH: add iinfo, finfo
ev-br Jan 6, 2023
2830ada
MAINT: update .gitignore
ev-br Jan 6, 2023
91b3cc6
Rudimentary autogen binary ufuncs input type fix
honno Jan 9, 2023
8 changes: 3 additions & 5 deletions .gitignore
@@ -1,7 +1,5 @@
__pycache__/*
autogen/__pycache__
torch_np/__pycache__/*
torch_np/tests/__pycache__/*
torch_np/tests/numpy_tests/core/__pycache__/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
.coverage

11 changes: 7 additions & 4 deletions autogen/gen_ufuncs.py
@@ -1,4 +1,6 @@
from dump_namespace import grab_namespace, get_signature
from collections import defaultdict
from warnings import warn
from .dump_namespace import grab_namespace, get_signature

import numpy as np

@@ -138,7 +140,7 @@ def test_{np_name}():



test_header = header + """\
test_header = """\
import numpy as np
import torch

@@ -168,14 +170,15 @@ def {np_name}(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K'
test_template = """

def test_{np_name}():
assert_allclose(np.{np_name}(0.5, 0.6),
{np_name}(0.5, 0.6), atol=1e-7, check_dtype=False)
assert_allclose(np.{np_name}({args}),
np.{np_name}({args}), atol=1e-7, check_dtype=False)

"""



skip = {np.divmod, # two outputs
np.matmul, # array inputs
}


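For orientation, here is roughly how such a template expands; the "add" name and the "0.5, 0.6" smoke arguments below are illustrative placeholders, not necessarily what gen_ufuncs.py emits for any given ufunc:

# --- illustration, not part of the diff: expanding the test template with str.format ---
test_template = """
def test_{np_name}():
    assert_allclose(np.{np_name}({args}),
                    np.{np_name}({args}), atol=1e-7, check_dtype=False)
"""

# a hypothetical binary ufunc name plus a pair of valid smoke arguments
print(test_template.format(np_name="add", args="0.5, 0.6"))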
7 changes: 5 additions & 2 deletions torch_np/__init__.py
@@ -1,12 +1,15 @@
from ._dtypes import *
from ._scalar_types import *
from ._wrapper import *
from . import testing
#from . import testing

from ._unary_ufuncs import *
from ._binary_ufuncs import *
from ._ndarray import can_cast, result_type, newaxis
from ._util import AxisError

from ._getlimits import iinfo, finfo
from ._getlimits import errstate

inf = float('inf')
nan = float('nan')

83 changes: 62 additions & 21 deletions torch_np/_dtypes.py
@@ -8,12 +8,20 @@
import builtins
import torch

from . import _scalar_types


__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype']


# Define analogs of numpy dtypes supported by pytorch.

class dtype:
def __init__(self, name):
def __init__(self, name, /):
if isinstance(name, dtype):
_name = name.name
elif hasattr(name, 'dtype'):
_name = name.dtype.name
elif name in python_types_dict:
_name = python_types_dict[name]
elif name in dt_names:
@@ -22,6 +30,9 @@ def __init__(self, name):
_name = typecode_chars_dict[name]
elif name in dt_aliases_dict:
_name = dt_aliases_dict[name]
# the check must come last, so that 'name' is not a string
elif issubclass(name, _scalar_types.generic):
_name = name.name
else:
raise TypeError(f"data type '{name}' not understood")
self._name = _name
@@ -30,6 +41,10 @@ def __init__(self, name):
def name(self):
return self._name

@property
def type(self):
return _scalar_types._typemap[self._name]

@property
def typecode(self):
return _typecodes_from_dtype_dict[self._name]
@@ -38,14 +53,31 @@ def __eq__(self, other):
if isinstance(other, dtype):
return self._name == other.name
else:
other_instance = dtype(other)
try:
other_instance = dtype(other)
except TypeError:
return False
Collaborator

@ev-br ev-br Jan 9, 2023


Not sure why it is better than raising? (Uhm, that was me. Meaning this code is not right, long term.)

return self._name == other_instance.name

def __hash__(self):
return hash(self._name)

def __repr__(self):
return f'dtype("{self.name}")'

__str__ = __repr__

def itemsize(self):
elem = self.type(1)
return elem.get().element_size()

def __getstate__(self):
return self._name

def __setstate__(self, value):
self._name = value



dt_names = ['float16', 'float32', 'float64',
'complex64', 'complex128',
@@ -58,6 +90,7 @@ def __repr__(self):


dt_aliases_dict = {
'u1' : 'uint8',
'i1' : 'int8',
'i2' : 'int16',
'i4' : 'int32',
@@ -75,7 +108,12 @@ def __repr__(self):
python_types_dict = {
int: 'int64',
float: 'float64',
builtins.bool: 'bool'
complex: 'complex128',
builtins.bool: 'bool',
# also allow stringified names of python types
int.__name__ : 'int64',
float.__name__ : 'float64',
complex.__name__: 'complex128',
}


@@ -101,24 +139,13 @@ def __repr__(self):
typecodes = {'All': 'efdFDBbhil?',
'AllFloat': 'efdFD',
'AllInteger': 'Bbhil',
'Integer': 'bhil',
'UnsignedInteger': 'B',
'Float': 'efd',
'Complex': 'FD',
}


float16 = dtype("float16")
float32 = dtype("float32")
float64 = dtype("float64")
complex64 = dtype("complex64")
complex128 = dtype("complex128")
uint8 = dtype("uint8")
int8 = dtype("int8")
int16 = dtype("int16")
int32 = dtype("int32")
int64 = dtype("int64")
bool = dtype("bool")

intp = int64 # XXX
int_ = int64

# Map the torch-suppored subset dtypes to local analogs
# "quantized" types not available in numpy, skip
_dtype_from_torch_dict = {
@@ -183,6 +210,23 @@ def is_integer(dtyp):
return dtyp.typecode in typecodes['AllInteger']



def issubclass_(arg, klass):
try:
return issubclass(arg, klass)
except TypeError:
return False


def issubdtype(arg1, arg2):
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
if not issubclass_(arg1, _scalar_types.generic):
arg1 = dtype(arg1).type
if not issubclass_(arg2, _scalar_types.generic):
arg2 = dtype(arg2).type
return issubclass(arg1, arg2)


# The casting below is defined *with dtypes only*, so no value-based casting!

# These two dicts are autogenerated with autogen/gen_dtypes.py,
@@ -216,6 +260,3 @@ def is_integer(dtyp):

########################## end autogenerated part


__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_']
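Taken together, the _dtypes.py changes let dtype() accept scalar type classes and route issubdtype through them. Below is a self-contained toy reconstruction of that pattern, with plain Python classes standing in for torch_np._scalar_types; it is a sketch of the idea, not the module itself:

# --- illustration, not part of the diff: toy reconstruction of the issubdtype pattern ---
class generic: pass                      # stand-in for _scalar_types.generic
class floating(generic): pass
class float32(floating):
    name = 'float32'

_typemap = {'float32': float32}          # stand-in for _scalar_types._typemap

class dtype:
    def __init__(self, name, /):
        self._name = name if isinstance(name, str) else name.name
    @property
    def type(self):
        return _typemap[self._name]

def issubclass_(arg, klass):
    # like issubclass, but returns False instead of raising on non-class arguments
    try:
        return issubclass(arg, klass)
    except TypeError:
        return False

def issubdtype(arg1, arg2):
    # anything that is not already a scalar type class is routed through dtype(...).type
    if not issubclass_(arg1, generic):
        arg1 = dtype(arg1).type
    if not issubclass_(arg2, generic):
        arg2 = dtype(arg2).type
    return issubclass(arg1, arg2)

assert issubdtype('float32', floating)   # dtype names resolve to their scalar types
assert issubdtype(float32, generic)      # scalar type classes are used as-is

The same routing is what lets the real np.issubdtype accept dtype instances, dtype names and scalar type classes interchangeably, mirroring numpy.issubdtype.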

19 changes: 19 additions & 0 deletions torch_np/_getlimits.py
@@ -0,0 +1,19 @@
import torch
from . import _dtypes

def finfo(dtyp):
torch_dtype = _dtypes.torch_dtype_from(dtyp)
return torch.finfo(torch_dtype)


def iinfo(dtyp):
torch_dtype = _dtypes.torch_dtype_from(dtyp)
return torch.iinfo(torch_dtype)


import contextlib

# FIXME: this is only a stub
@contextlib.contextmanager
def errstate(*args, **kwds):
yield
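Since finfo and iinfo only translate the dtype and then delegate, the limits they report are exactly those of the underlying torch dtypes. A minimal sketch; the np.finfo/np.iinfo spellings in the comments assume the torch_np top-level exports added above:

# --- illustration, not part of the diff: the values finfo/iinfo report come from torch ---
import torch

print(torch.finfo(torch.float32).eps)    # the machine epsilon np.finfo would report for a float32 dtype
print(torch.iinfo(torch.int64).max)      # the maximum np.iinfo would report for an int64 dtype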
6 changes: 6 additions & 0 deletions torch_np/_helpers.py
@@ -125,3 +125,9 @@ def allow_only_single_axis(axis):
if len(axis) != 1:
raise NotImplementedError("does not handle tuple axis")
return axis[0]


def to_tensors(*inputs):
"""Convert all ndarrays from `inputs` to tensors."""
return tuple([value.get() if isinstance(value, ndarray) else value
for value in inputs])
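A small sketch of what to_tensors is for, with a toy stand-in for the ndarray wrapper (the real helper uses torch_np._ndarray.ndarray and is picked up e.g. by __getitem__ in the next file):

# --- illustration, not part of the diff: how to_tensors normalises mixed arguments ---
import torch

class ndarray:                       # toy stand-in: wraps a tensor, exposes it via .get()
    def __init__(self, tensor):
        self._tensor = tensor
    def get(self):
        return self._tensor

def to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors; pass everything else through."""
    return tuple(value.get() if isinstance(value, ndarray) else value
                 for value in inputs)

idx = to_tensors(ndarray(torch.tensor([0, 2])), 1, slice(None))
# -> (tensor([0, 2]), 1, slice(None, None, None)) -- safe to hand to tensor.__getitem__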
37 changes: 36 additions & 1 deletion torch_np/_ndarray.py
@@ -110,6 +110,16 @@ def __neq__(self, other):
def __gt__(self, other):
return asarray(self._tensor > asarray(other).get())

def __lt__(self, other):
return asarray(self._tensor < asarray(other).get())

def __ge__(self, other):
return asarray(self._tensor >= asarray(other).get())

def __le__(self, other):
return asarray(self._tensor <= asarray(other).get())
Collaborator


NB: this will need to be redone similar to gh-17



def __bool__(self):
try:
return bool(self._tensor)
@@ -131,6 +141,15 @@ def __hash__(self):
def __float__(self):
return float(self._tensor)

# XXX : are single-element ndarrays scalars?
def is_integer(self):
if self.shape == ():
if _dtypes.is_integer(self.dtype):
return True
return self._tensor.item().is_integer()
else:
return False


### sequence ###
def __len__(self):
@@ -162,6 +181,15 @@ def __truediv__(self, other):
other_tensor = asarray(other).get()
return asarray(self._tensor.__truediv__(other_tensor))

def __or__(self, other):
other_tensor = asarray(other).get()
return asarray(self._tensor.__or__(other_tensor))

def __ior__(self, other):
other_tensor = asarray(other).get()
return asarray(self._tensor.__ior__(other_tensor))


def __invert__(self):
return asarray(self._tensor.__invert__())

@@ -307,7 +335,8 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,

### indexing ###
def __getitem__(self, *args, **kwds):
return ndarray._from_tensor_and_base(self._tensor.__getitem__(*args, **kwds), self)
t_args = _helpers.to_tensors(*args)
return ndarray._from_tensor_and_base(self._tensor.__getitem__(*t_args, **kwds), self)

def __setitem__(self, index, value):
value = asarray(value).get()
@@ -320,6 +349,8 @@ def asarray(a, dtype=None, order=None, *, like=None):
raise NotImplementedError

if isinstance(a, ndarray):
if dtype is not None and dtype != a.dtype:
a = a.astype(dtype)
return a

if isinstance(a, (list, tuple)):
@@ -356,6 +387,10 @@ def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,

if isinstance(object, ndarray):
result = object._tensor

if dtype != object.dtype:
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = result.to(torch_dtype)
else:
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = torch.as_tensor(object, dtype=torch_dtype)
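Putting the _ndarray.py changes together, a hedged usage sketch. It assumes asarray and dtype are re-exported at the torch_np top level, which this diff does not itself show:

# --- illustration, not part of the diff: intended effect of the _ndarray.py changes ---
import torch_np as np

a = np.asarray([1, 2, 3])

# asarray(arr) still returns the same object when no conversion is requested ...
assert np.asarray(a) is a
# ... but an explicit, different dtype now triggers a cast via astype.
b = np.asarray(a, dtype=np.dtype('float64'))

# The new rich comparisons (<, <=, >=) return arrays, mirroring numpy.
mask = a < np.asarray(2)
# __getitem__ now routes indices through to_tensors, so ndarray indices work as well.
first = a[np.asarray(0)]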