-
Notifications
You must be signed in to change notification settings - Fork 135
Remove rarely used shape utilities #1016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,7 +41,7 @@ | |
) | ||
from pytensor.tensor.math import max as pt_max | ||
from pytensor.tensor.math import sum as pt_sum | ||
from pytensor.tensor.shape import specify_broadcastable | ||
from pytensor.tensor.shape import Shape_i, specify_broadcastable | ||
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor | ||
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector | ||
from pytensor.tensor.variable import TensorVariable | ||
|
@@ -1194,23 +1194,22 @@ | |
self.return_index = return_index | ||
self.return_inverse = return_inverse | ||
self.return_counts = return_counts | ||
if axis is not None and axis < 0: | ||
raise ValueError("Axis cannot be negative.") | ||
self.axis = axis | ||
|
||
def make_node(self, x): | ||
x = ptb.as_tensor_variable(x) | ||
self_axis = self.axis | ||
if self_axis is None: | ||
axis = self.axis | ||
if axis is None: | ||
out_shape = (None,) | ||
else: | ||
if self_axis < 0: | ||
self_axis += x.type.ndim | ||
if self_axis < 0 or self_axis >= x.type.ndim: | ||
if axis >= x.type.ndim: | ||
raise ValueError( | ||
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}" | ||
f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}." | ||
) | ||
out_shape = tuple( | ||
s if s == 1 and axis != self_axis else None | ||
for axis, s in enumerate(x.type.shape) | ||
None if dim == axis else s for dim, s in enumerate(x.type.shape) | ||
) | ||
|
||
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()] | ||
|
@@ -1224,60 +1223,37 @@ | |
return Apply(self, [x], outputs) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
x = inputs[0] | ||
z = output_storage | ||
param = {} | ||
if self.return_index: | ||
param["return_index"] = True | ||
if self.return_inverse: | ||
param["return_inverse"] = True | ||
if self.return_counts: | ||
param["return_counts"] = True | ||
if self.axis is not None: | ||
param["axis"] = self.axis | ||
outs = np.unique(x, **param) | ||
if ( | ||
(not self.return_inverse) | ||
and (not self.return_index) | ||
and (not self.return_counts) | ||
): | ||
z[0][0] = outs | ||
else: | ||
[x] = inputs | ||
outs = np.unique( | ||
x, | ||
return_index=self.return_index, | ||
return_inverse=self.return_inverse, | ||
return_counts=self.return_counts, | ||
axis=self.axis, | ||
) | ||
if isinstance(outs, tuple): | ||
for i in range(len(outs)): | ||
z[i][0] = outs[i] | ||
output_storage[i][0] = outs[i] | ||
else: | ||
output_storage[0][0] = outs | ||
|
||
def infer_shape(self, fgraph, node, i0_shapes): | ||
ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes) | ||
if self.axis is not None: | ||
self_axis = self.axis | ||
ndim = len(i0_shapes[0]) | ||
if self_axis < 0: | ||
self_axis += ndim | ||
if self_axis < 0 or self_axis >= ndim: | ||
raise RuntimeError( | ||
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}." | ||
) | ||
ret[0] = tuple( | ||
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim) | ||
) | ||
[x_shape] = i0_shapes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure I understand what is happening in this function, but just to check, shouldn't there be a case for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is. We set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The big picture is the function tries to return a graph for the shape of the outputs given the input shapes (and possibly values, which you could retrieve from node.inputs). The default graph of the shape is just output.shape, which we try to avoid when possible, as we would like to avoid computing the Op in order to find out its shape. For This method is combining dims we can know from the input shapes and those that we can only get after we compute the outputs with |
||
shape0_op = Shape_i(0) | ||
out_shapes = [(shape0_op(out),) for out in node.outputs] | ||
|
||
axis = self.axis | ||
if axis is not None: | ||
shape = list(x_shape) | ||
shape[axis] = Shape_i(axis)(node.outputs[0]) | ||
out_shapes[0] = tuple(shape) | ||
|
||
if self.return_inverse: | ||
if self.axis is None: | ||
shape = (prod(i0_shapes[0]),) | ||
else: | ||
shape = (i0_shapes[0][self_axis],) | ||
if self.return_index: | ||
ret[2] = shape | ||
return ret | ||
ret[1] = shape | ||
return ret | ||
return ret | ||
|
||
def __setstate__(self, state): | ||
self.__dict__.update(state) | ||
# For backwards compatibility with pickled instances of Unique that | ||
# did not have the axis parameter specified | ||
if "axis" not in state: | ||
self.axis = None | ||
shape = prod(x_shape) if self.axis is None else x_shape[axis] | ||
return_index_out_idx = 2 if self.return_index else 1 | ||
out_shapes[return_index_out_idx] = (shape,) | ||
|
||
return out_shapes | ||
|
||
|
||
def unique( | ||
|
@@ -1293,6 +1269,9 @@ | |
* the number of times each unique value comes up in the input array | ||
|
||
""" | ||
ar = as_tensor_variable(ar) | ||
if axis is not None: | ||
axis = normalize_axis_index(axis, ar.ndim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Armavica here is where we allow negative axis for the user |
||
return Unique(return_index, return_inverse, return_counts, axis)(ar) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove this possibility?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because it simplifies the logic in the Op. The helper that users actually call,
`pt.unique`,
handles the negative axis and passes a non-negative one to the Op. Users don't really instantiate the Op themselves.