port scalarmath tests #16

Closed · wants to merge 3 commits
28 changes: 25 additions & 3 deletions torch_np/_dtypes.py
@@ -67,6 +67,7 @@ def __repr__(self):

    __str__ = __repr__

    @property
    def itemsize(self):
        elem = self.type(1)
        return elem.get().element_size()
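
A quick sanity check of the new property (hypothetical usage; dtype is this module's constructor, and itemsize is read off the backing torch scalar):

>>> dtype('float64').itemsize   # a torch.float64 element is 8 bytes
8
>>> dtype('int16').itemsize
2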
@@ -193,6 +194,8 @@ def torch_dtype_from(dtyp):
    raise TypeError


# ### Defaults and dtype discovery

def default_int_type():
    return dtype('int64')

@@ -201,14 +204,33 @@ def default_float_type():
    return dtype('float64')


def default_complex_type():
return dtype('complex128')


def is_floating(dtyp):
dtyp = dtype(dtyp)
return dtyp.typecode in typecodes['AllFloat']
return issubclass(dtyp.type, _scalar_types.floating)


def is_integer(dtyp):
dtyp = dtype(dtyp)
return dtyp.typecode in typecodes['AllInteger']

return issubclass(dtyp.type, _scalar_types.integer)


def get_default_dtype_for(dtyp):
typ = dtype(dtyp).type
if issubclass(typ, _scalar_types.integer):
result = default_int_type()
elif issubclass(typ, _scalar_types.floating):
result = default_float_type()
elif issubclass(typ, _scalar_types.complexfloating):
result = default_complex_type()
elif issubclass(typ, _scalar_types.bool_):
result = dtype('bool')
else:
raise TypeError("dtype %s not understood." % dtyp)
return result
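
For reference, a sketch of how the discovery helpers above compose (hypothetical usage; each concrete dtype maps to the project-wide default of its kind):

>>> get_default_dtype_for('float32')    # -> the float64 default
>>> get_default_dtype_for('int8')       # -> the int64 default
>>> get_default_dtype_for('complex64')  # -> the complex128 default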


def issubclass_(arg, klass):
165 changes: 117 additions & 48 deletions torch_np/_ndarray.py
@@ -81,6 +81,18 @@ def base(self):
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return asarray(self._tensor.real)

    @property
    def imag(self):
        try:
            return asarray(self._tensor.imag)
        except RuntimeError:
            zeros = torch.zeros_like(self._tensor)
            return ndarray._from_tensor_and_base(zeros, None)
Collaborator:
Why not simply return asarray(zeros)?

Collaborator Author:
At some point we'll need to rationalize these two forms, agree.
Basically, asarray is anything array-like in, array out; here we explicitly construct the tensor, so my fingers naturally typed this line. The line above, with asarray(self._tensor.imag), should be changed to follow line 94.

Collaborator:
I mean, asarray(Tensor) has the same semantics as _from_tensor_and_base(Tensor, None), so we can decide to always prefer the first one over the latter one for conciseness and consistency.

Collaborator Author:
Not always:

In [2]: a = np.zeros(3)

In [3]: np.asarray(a) is a
Out[3]: True

Collaborator:
Note that I suggested doing so when using torch.Tensors.

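For context on the except RuntimeError branch above, the NumPy behavior being matched: real-valued arrays expose an all-zeros imag rather than raising, e.g.

In [4]: a = np.ones(3)

In [5]: a.imag
Out[5]: array([0., 0., 0.])
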

    # ctors
    def astype(self, dtype):
        newt = ndarray()
@@ -102,6 +114,13 @@ def __str__(self):

### comparisons ###
def __eq__(self, other):
try:
t_other = asarray(other).get
except RuntimeError:
# Failed to convert other to array: definitely not equal.
# TODO: generalize, delegate to ufuncs
falsy = torch.full(self.shape, fill_value=False, dtype=bool)
return asarray(falsy)
return asarray(self._tensor == asarray(other).get())

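A sketch of what the fallback buys (assuming the operand, e.g. None, fails tensor conversion with a RuntimeError): comparing against non-convertible objects now returns all-False, NumPy-style, instead of raising.

>>> a = asarray([1, 2])
>>> a == None    # None cannot become a tensor -> all-False array, as in numpy
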
def __neq__(self, other):
@@ -119,7 +138,6 @@ def __ge__(self, other):
    def __le__(self, other):
        return asarray(self._tensor <= asarray(other).get())

    def __bool__(self):
        try:
            return bool(self._tensor)
@@ -141,6 +159,9 @@ def __hash__(self):
    def __float__(self):
        return float(self._tensor)

    def __int__(self):
        return int(self._tensor)

    # XXX : are single-element ndarrays scalars?
    def is_integer(self):
        if self.shape == ():
@@ -167,7 +188,10 @@ def __iadd__(self, other):

    def __sub__(self, other):
        other_tensor = asarray(other).get()
-        return asarray(self._tensor.__sub__(other_tensor))
+        try:
+            return asarray(self._tensor.__sub__(other_tensor))
+        except RuntimeError as e:
+            raise TypeError(e.args)

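One case this RuntimeError-to-TypeError translation targets (a sketch, in the spirit of the scalarmath tests): boolean subtraction, which torch rejects with a RuntimeError while NumPy reports a TypeError.

>>> asarray([True]) - asarray([False])   # torch's RuntimeError resurfaces
...                                      # as TypeError, matching numpy
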
    def __mul__(self, other):
        other_tensor = asarray(other).get()
@@ -177,10 +201,30 @@ def __rmul__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__rmul__(other_tensor))

    def __floordiv__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__floordiv__(other_tensor))

    def __ifloordiv__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__ifloordiv__(other_tensor))

    def __truediv__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__truediv__(other_tensor))

    def __itruediv__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__itruediv__(other_tensor))

    def __mod__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__mod__(other_tensor))

    def __imod__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__imod__(other_tensor))

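The floor-division and modulo hooks added above should satisfy the usual identity (a // b) * b + a % b == a, since tensor // and % are both floor-style in recent torch; a minimal check:

>>> a, b = asarray([7, -7]), asarray([3, 3])
>>> (a // b) * b + a % b   # elementwise [7, -7], i.e. equal to a
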
    def __or__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__or__(other_tensor))
@@ -189,10 +233,22 @@ def __ior__(self, other):
        other_tensor = asarray(other).get()
        return asarray(self._tensor.__ior__(other_tensor))

    def __invert__(self):
        return asarray(self._tensor.__invert__())

    def __abs__(self):
        return asarray(self._tensor.__abs__())

    def __neg__(self):
        try:
            return asarray(self._tensor.__neg__())
        except RuntimeError as e:
            raise TypeError(e.args)

    def __pow__(self, exponent):
        exponent_tensor = asarray(exponent).get()
        return asarray(self._tensor.__pow__(exponent_tensor))

    ### methods to match namespace functions

    def squeeze(self, axis=None):
@@ -301,7 +357,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):

        if dtype is None:
            dtype = self.dtype
-        if not _dtypes.is_floating(dtype):
+        if _dtypes.is_integer(dtype):
            dtype = _dtypes.default_float_type()
        torch_dtype = _dtypes.torch_dtype_from(dtype)

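With the narrowed check, an integer input still promotes to the default float for mean, matching NumPy, while non-integer dtypes are now left as given; a sketch:

>>> asarray([1, 2, 3]).mean()   # integer input -> float64 result, 2.0
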
@@ -321,7 +377,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,

        if dtype is None:
            dtype = self.dtype
-        if not _dtypes.is_floating(dtype):
+        if _dtypes.is_integer(dtype):
            dtype = _dtypes.default_float_type()
        torch_dtype = _dtypes.torch_dtype_from(dtype)

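Same narrowing for sum: only integer dtypes are promoted, so for instance a complex sum keeps its dtype (a sketch):

>>> asarray([1 + 1j, 2 - 1j]).sum()   # stays complex, not cast to float64
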
@@ -343,67 +399,80 @@ def __setitem__(self, index, value):
        return self._tensor.__setitem__(index, value)


-def asarray(a, dtype=None, order=None, *, like=None):
-    _util.subok_not_ok(like)
-    if order is not None:
+# This is ideally the only place which talks to ndarray directly.
+# The rest goes through asarray (preferred) or array.
+
+def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
+          like=None):
+    _util.subok_not_ok(like, subok)
+    if order != 'K':
        raise NotImplementedError

-    if isinstance(a, ndarray):
-        if dtype is not None and dtype != a.dtype:
-            a = a.astype(dtype)
-        return a
+    # a happy path
+    if isinstance(object, ndarray):
+        if copy is False and dtype is None and ndmin <= object.ndim:
+            return object

-    if isinstance(a, (list, tuple)):
-        # handle lists of ndarrays, [1, [2, 3], ndarray(4)] etc
+    # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
+    if isinstance(object, (list, tuple)):
        a1 = []
-        for elem in a:
+        for elem in object:
            if isinstance(elem, ndarray):
                a1.append(elem.get().tolist())
            else:
                a1.append(elem)
+        object = a1
+
+    # get the tensor from "object"
+    if isinstance(object, ndarray):
+        tensor = object._tensor
+        base = object
+    elif isinstance(object, torch.Tensor):
+        tensor = object
+        base = None
    else:
-        a1 = a
+        tensor = torch.as_tensor(object)
+        base = None

-    torch_dtype = _dtypes.torch_dtype_from(dtype)
+        # At this point, `tensor.dtype` is the pytorch default. Our default may
+        # differ, so need to typecast. However, we cannot just do `tensor.to`,
+        # because if our desired dtype is wider than pytorch's, `tensor`
+        # may have lost precision:

-    # This and array(...) are the only places which talk to ndarray directly.
-    # The rest goes through asarray (preferred) or array.
-    out = ndarray()
-    tt = torch.as_tensor(a1, dtype=torch_dtype)
-    out._tensor = tt
-    return out
+        # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
+
+        # Therefore, we treat `tensor.dtype` as a hint, and convert the
+        # original object *again*, this time with an explicit dtype.
+        dtyp = _dtypes.dtype_from_torch(tensor.dtype)
+        default = _dtypes.get_default_dtype_for(dtyp)
+        torch_dtype = _dtypes.torch_dtype_from(default)
Comment on lines +446 to +448
Collaborator:
We may want to overload get_default_dtype_for to also accept PyTorch dtypes to simplify this code.

Collaborator:
Also, I get the feeling that if PyTorch defaults are the same as those in NumPy (via set_default...), simply doing as_tensor should do the trick. Is this possible? Otherwise, as you mention above, we should make sure we just call as_tensor from here, as otherwise we'll inadvertently get incorrect results.

Collaborator Author:
Let's maybe not rely on this, not just yet at least. When we have more complete coverage, let's experiment with how much we can peel out.

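The precision trap from the comment block above ("try it!"), spelled out: with torch's float32 default, 1e12 rounds to 999999995904, so the round trip loses 4096.

>>> import torch
>>> int(torch.as_tensor(1e12)) - 1e12   # converted at the default float32
-4096.0
>>> int(torch.as_tensor(1e12, dtype=torch.float64)) - 1e12
0.0
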

-def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
-          like=None):
-    _util.subok_not_ok(like, subok)
-    if order != 'K':
-        raise NotImplementedError
-
-    if isinstance(object, (list, tuple)):
-        obj = asarray(object)
-        return array(obj, dtype, copy=copy, order=order, subok=subok,
-                     ndmin=ndmin, like=like)
+        tensor = torch.as_tensor(object, dtype=torch_dtype)

-    if isinstance(object, ndarray):
-        result = object._tensor
-
-        if dtype != object.dtype:
-            torch_dtype = _dtypes.torch_dtype_from(dtype)
-            result = result.to(torch_dtype)
-    else:
+    # type cast if requested
+    if dtype is not None:
        torch_dtype = _dtypes.torch_dtype_from(dtype)
-        result = torch.as_tensor(object, dtype=torch_dtype)
+        tensor = tensor.to(torch_dtype)
+        base = None

+    # adjust ndim if needed
+    ndim_extra = ndmin - tensor.ndim
+    if ndim_extra > 0:
+        tensor = tensor.view((1,)*ndim_extra + tensor.shape)
+        base = None

+    # copy if requested
    if copy:
-        result = result.clone()
+        tensor = tensor.clone()
+        base = None

-    ndim_extra = ndmin - result.ndim
-    if ndim_extra > 0:
-        result = result.reshape((1,)*ndim_extra + result.shape)
-    out = ndarray()
-    out._tensor = result
-    return out
+    return ndarray._from_tensor_and_base(tensor, base)


+def asarray(a, dtype=None, order=None, *, like=None):
+    if order is None:
+        order = 'K'
+    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)



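A sketch of the resulting copy semantics: asarray is a no-copy round-trip for existing ndarrays (exactly the np.asarray(a) is a behavior from the review thread above), while array copies by default.

>>> a = asarray([1, 2, 3])
>>> asarray(a) is a    # happy path: copy=False, dtype=None -> same object
True
>>> array(a) is a      # copy=True by default -> cloned tensor, new ndarray
False
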
12 changes: 9 additions & 3 deletions torch_np/_scalar_types.py
@@ -25,7 +25,12 @@ def __new__(self, value):
        if isinstance(value, _ndarray.ndarray):
            tensor = value.get()
        else:
-            tensor = torch.as_tensor(value, dtype=torch_dtype)
+            try:
+                tensor = torch.as_tensor(value, dtype=torch_dtype)
+            except RuntimeError as e:
+                if "Overflow" in str(e):
+                    raise OverflowError(e.args)
+                raise e
        #
        # With numpy:
        # >>> a = np.ones(3)
@@ -135,6 +140,7 @@ class bool_(generic):
half = float16
single = float32
double = float64
float_ = float64

csingle = complex64
cdouble = complex128
@@ -169,8 +175,8 @@ class bool_(generic):
__all__ = list(_typemap.keys())
__all__.remove('bool')

-__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong', 'ubyte', 'half', 'single', 'double',
-            'csingle', 'cdouble']
+__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong',
+            'ubyte', 'half', 'single', 'double', 'csingle', 'cdouble', 'float_']
__all__ += ['sctypes']
__all__ += ['generic', 'number',
            'integer', 'signedinteger', 'unsignedinteger',
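
Given the OverflowError translation in __new__ above, oversized Python ints should now surface the NumPy way (a sketch; it assumes the module's int64 scalar type, and that torch's error message here contains "Overflow", which is what the except clause keys on):

>>> int64(2**100)   # torch raises RuntimeError('Overflow when unpacking long'),
...                 # re-raised here as OverflowError, as np.int64 would do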
17 changes: 9 additions & 8 deletions torch_np/_wrapper.py
@@ -510,6 +510,11 @@ def argwhere(a):
    return asarray(torch.argwhere(tensor))


def abs(a):
    # FIXME: should go the other way, together with other ufuncs
    arr = asarray(a)
    return arr.__abs__()

from ._ndarray import axis_out_keepdims_wrapper

@axis_out_keepdims_wrapper
@@ -702,18 +707,14 @@ def angle(z, deg=False):
    return result


-@asarray_replacer()
def real(a):
-    return torch.real(a)
+    arr = asarray(a)
+    return arr.real


-@asarray_replacer()
def imag(a):
-    # torch.imag raises on real-valued inputs
-    if torch.is_complex(a):
-        return torch.imag(a)
-    else:
-        return torch.zeros_like(a)
+    arr = asarray(a)
+    return arr.imag


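With the wrappers now routed through the ndarray properties, any array-like input works and real-valued input no longer trips torch.imag (a sketch):

>>> real([1.0, 2.0])   # plain lists are fine: asarray() handles conversion
>>> imag([1.0, 2.0])   # all zeros instead of a RuntimeError
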
@asarray_replacer()
1 change: 1 addition & 0 deletions torch_np/testing/__init__.py
@@ -1,5 +1,6 @@
from .utils import (assert_equal, assert_array_equal, assert_almost_equal,
                    assert_warns, assert_)
from .utils import _gen_alignment_data

from .testing import assert_allclose # FIXME
