
Remove rarely used shape utilities #1016


Merged: 2 commits, Oct 7, 2024

97 changes: 38 additions & 59 deletions pytensor/tensor/extra_ops.py
@@ -41,7 +41,7 @@
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
@@ -1194,23 +1194,22 @@
self.return_index = return_index
self.return_inverse = return_inverse
self.return_counts = return_counts
+if axis is not None and axis < 0:
+raise ValueError("Axis cannot be negative.")

Codecov / codecov/patch warning: added line pytensor/tensor/extra_ops.py#L1198 was not covered by tests.

Comment on lines +1197 to +1198
Member:

Why remove this possibility?

Member Author:

Because it simplifies the logic in the Op. The helper that users actually call, pt.unique, handles a negative axis and passes a positive one to the Op. Users don't really create Ops themselves.

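As a rough sketch of the split described above (my own example, not code from this PR): the Op constructor now rejects a negative axis outright, and normalization is left to the user-facing helper.

import pytensor.tensor as pt
from pytensor.tensor.extra_ops import Unique

# Assumes the Unique Op as defined in this diff.
try:
    Unique(axis=-1)
except ValueError as err:
    print(err)  # Axis cannot be negative.
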
self.axis = axis

def make_node(self, x):
x = ptb.as_tensor_variable(x)
-self_axis = self.axis
-if self_axis is None:
+axis = self.axis
+if axis is None:
out_shape = (None,)
else:
-if self_axis < 0:
-self_axis += x.type.ndim
-if self_axis < 0 or self_axis >= x.type.ndim:
+if axis >= x.type.ndim:
raise ValueError(
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}"
f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}."
)
out_shape = tuple(
-s if s == 1 and axis != self_axis else None
-for axis, s in enumerate(x.type.shape)
+None if dim == axis else s for dim, s in enumerate(x.type.shape)
)

outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
@@ -1224,60 +1223,37 @@
return Apply(self, [x], outputs)

def perform(self, node, inputs, output_storage):
-x = inputs[0]
-z = output_storage
-param = {}
-if self.return_index:
-param["return_index"] = True
-if self.return_inverse:
-param["return_inverse"] = True
-if self.return_counts:
-param["return_counts"] = True
-if self.axis is not None:
-param["axis"] = self.axis
-outs = np.unique(x, **param)
-if (
-(not self.return_inverse)
-and (not self.return_index)
-and (not self.return_counts)
-):
-z[0][0] = outs
-else:
+[x] = inputs
+outs = np.unique(
+x,
+return_index=self.return_index,
+return_inverse=self.return_inverse,
+return_counts=self.return_counts,
+axis=self.axis,
+)
+if isinstance(outs, tuple):
for i in range(len(outs)):
-z[i][0] = outs[i]
+output_storage[i][0] = outs[i]
+else:
+output_storage[0][0] = outs

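A quick NumPy aside (my own example) on why the isinstance(outs, tuple) branch above is sufficient: np.unique returns a bare array when no optional outputs are requested, and a tuple as soon as any of return_index / return_inverse / return_counts is set.

import numpy as np

print(type(np.unique([1, 1, 2])))                      # <class 'numpy.ndarray'>
print(type(np.unique([1, 1, 2], return_counts=True)))  # <class 'tuple'>
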
def infer_shape(self, fgraph, node, i0_shapes):
-ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes)
-if self.axis is not None:
-self_axis = self.axis
-ndim = len(i0_shapes[0])
-if self_axis < 0:
-self_axis += ndim
-if self_axis < 0 or self_axis >= ndim:
-raise RuntimeError(
-f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
-)
-ret[0] = tuple(
-fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
-)
+[x_shape] = i0_shapes
Member:

I am not sure I understand what is happening in this function, but just to check, shouldn't there be a case for return_index and return_counts as well?

Member Author:

There is. i0_shapes are the input dimensions, so that doesn't change with the number of outputs. return_index/return_counts are outputs, and they are always vectors.

We set out_shapes = [out.shape[0] for out in node.outputs] by default, which will always work for return_index and return_counts. Then we have special logic for the main output when axis is not None, and for return_inverse, whose shape is not just out.shape[0].

Member Author (@ricardoV94, Oct 7, 2024):

The big picture is that this function tries to return a graph for the shape of the outputs, given the input shapes (and possibly the input values, which you could retrieve from node.inputs). The default graph for the shape is just output.shape, which we try to avoid when possible, because we would like to avoid computing the Op just to find out its shape.

For unique we can do that for some of the output dimensions, but not all (we only know how many repeated values there are once we evaluate Unique).

This method combines the dims we can know from the input shapes with those we can only get after computing the outputs, via out.shape[0] or out.shape[x].
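A small illustration of the static side of this (my own example, not part of the PR): dimensions that can be read off the input keep their size in the output type, while the axis reduced by Unique stays unknown, which is why its shape graph has to go through the output itself.

import pytensor.tensor as pt

x = pt.matrix("x", shape=(None, 3))
values = pt.unique(x, axis=0)

# The second dim is known from the input; the number of unique rows is not.
print(values.type.shape)  # (None, 3)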

+shape0_op = Shape_i(0)
+out_shapes = [(shape0_op(out),) for out in node.outputs]
+
+axis = self.axis
+if axis is not None:
+shape = list(x_shape)
+shape[axis] = Shape_i(axis)(node.outputs[0])
+out_shapes[0] = tuple(shape)

if self.return_inverse:
-if self.axis is None:
-shape = (prod(i0_shapes[0]),)
-else:
-shape = (i0_shapes[0][self_axis],)
-if self.return_index:
-ret[2] = shape
-return ret
-ret[1] = shape
-return ret
-return ret
-
-def __setstate__(self, state):
-self.__dict__.update(state)
-# For backwards compatibility with pickled instances of Unique that
-# did not have the axis parameter specified
-if "axis" not in state:
-self.axis = None
+shape = prod(x_shape) if self.axis is None else x_shape[axis]
+return_index_out_idx = 2 if self.return_index else 1
+out_shapes[return_index_out_idx] = (shape,)
+
+return out_shapes


def unique(
@@ -1293,6 +1269,9 @@
* the number of times each unique value comes up in the input array

"""
+ar = as_tensor_variable(ar)
+if axis is not None:
+axis = normalize_axis_index(axis, ar.ndim)
Member Author:

@Armavica here is where we allow a negative axis for the user.

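For example (my sketch, assuming the helper above): a negative axis is converted before the Op is built, so the Op only ever sees a non-negative value.

import pytensor.tensor as pt

x = pt.tensor3("x")
out = pt.unique(x, axis=-1)
print(out.owner.op.axis)  # 2
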
return Unique(return_index, return_inverse, return_counts, axis)(ar)


3 changes: 1 addition & 2 deletions pytensor/tensor/rewriting/shape.py
@@ -42,7 +42,6 @@
Shape_i,
SpecifyShape,
Unbroadcast,
-shape_i,
specify_shape,
unbroadcast,
)
@@ -1060,7 +1059,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
# Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape):
if NoneConst.equals(sh):
-shape[i] = shape_i(x, i, fgraph)
+shape[i] = x.shape[i]

return [stack(shape).astype(np.int64)]

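The replacement above relies on x.shape[i] being an ordinary symbolic graph (Shape followed by indexing), so the rewrite no longer needs the fgraph-aware shape_i helper. A small sketch (my example):

import pytensor.tensor as pt

x = pt.matrix("x")
dim1 = x.shape[1]  # builds Shape(x)[1]; later shape rewrites can still simplify it
print(dim1.eval({x: [[1.0, 2.0, 3.0]]}))  # 3
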
10 changes: 0 additions & 10 deletions pytensor/tensor/shape.py
@@ -363,16 +363,6 @@ def recur(node):
return shape(var)[i]


-def shape_i_op(i):
-key = i
-if key not in shape_i_op.cache:
-shape_i_op.cache[key] = Shape_i(i)
-return shape_i_op.cache[key]
-
-
-shape_i_op.cache = {}  # type: ignore


def register_shape_i_c_code(typ, code, check_input, version=()):
"""
Tell Shape_i how to generate C code for an PyTensor Type.
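With the cached shape_i_op factory removed, the direct equivalent is to instantiate Shape_i yourself (a sketch, my own example). Shape_i declares its props, so two instances with the same index compare equal, which is presumably why the cache added little.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.shape import Shape_i

x = pt.matrix("x")
dim0 = Shape_i(0)(x)                      # what shape_i_op(0)(x) used to build
print(dim0.eval({x: np.zeros((4, 2))}))   # 4
print(Shape_i(0) == Shape_i(0))           # True: __props__-based equality
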
7 changes: 3 additions & 4 deletions pytensor/tensor/subtensor.py
@@ -38,7 +38,7 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
-from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
+from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
bscalar,
@@ -2705,10 +2705,9 @@ def is_bool_index(idx):
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries
+shape0_op = Shape_i(0)
if is_bool_index(idx):
-index_shapes.extend(
-(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
-)
+index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
elif isinstance(getattr(idx, "type", None), SliceType):
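For context on the boolean-index branch above (my example, not from the PR): the number of entries a boolean mask selects along an axis is the length of each nonzero() output, which is exactly what Shape_i(0) reads.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.shape import Shape_i

idx = pt.vector("idx", dtype="bool")
(nz,) = pt.nonzero(idx)
n_selected = Shape_i(0)(nz)
print(n_selected.eval({idx: np.array([True, False, True])}))  # 2
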
10 changes: 3 additions & 7 deletions tests/tensor/test_shape.py
@@ -8,15 +8,13 @@
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations
-from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
Reshape,
Shape,
@@ -26,7 +24,6 @@
_specify_shape,
reshape,
shape,
-shape_i,
shape_tuple,
specify_broadcastable,
specify_shape,
@@ -633,13 +630,12 @@ def test_nonstandard_shapes():
tl_shape = shape(tl)
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))

-# There's no `FunctionGraph`, so it should return a `Subtensor`
-tl_shape_i = shape_i(tl, 0)
+# Test specific dim
+tl_shape_i = shape(tl)[0]
assert isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2

-tl_fg = FunctionGraph([a, b], [tl], features=[ShapeFeature()])
-tl_shape_i = shape_i(tl, 0, fgraph=tl_fg)
+tl_shape_i = Shape_i(0)(tl)
assert not isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
