Propagate static output shapes in Split and avoid copy in C-impl #1343

Merged: 4 commits merged on Apr 8, 2025

5 changes: 1 addition & 4 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -136,10 +136,7 @@ def join(axis, *tensors):
def numba_funcify_Split(op, **kwargs):
@numba_basic.numba_njit
def split(tensor, axis, indices):
# Work around for https://github.com/numba/numba/issues/8257
axis = axis % tensor.ndim
axis = numba_basic.to_scalar(axis)
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())

return split

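For reference, here is a minimal NumPy-level sketch of what the simplified Numba kernel computes (illustrative values only; the real function is compiled with `numba_basic.numba_njit` and receives its arguments from the Op): the cumulative sum of the split sizes, dropping the last entry, gives the cut points `np.split` expects, and `axis` arrives as a 0-d array whose `.item()` yields a plain Python integer.

```python
import numpy as np

# Illustrative stand-ins for the kernel's runtime arguments (assumed values).
tensor = np.arange(12).reshape(6, 2)
indices = np.asarray([2, 3, 1])  # split sizes along the chosen axis
axis = np.asarray(0)             # 0-d array; .item() extracts the integer

# Sizes [2, 3, 1] -> cumsum [2, 5, 6] -> cut points [2, 5] for np.split.
parts = np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
print([p.shape for p in parts])  # [(2, 2), (3, 2), (1, 2)]
```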
201 changes: 92 additions & 109 deletions pytensor/tensor/basic.py
@@ -2171,8 +2171,6 @@
array([3, 4])
>>> c
array([5])

TODO: Don't make a copy in C impl
"""

len_splits = None
@@ -2201,26 +2199,46 @@
raise TypeError("`axis` parameter must be an integer scalar")

inputs = [x, axis, splits]
out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
outputs = [out_type() for i in range(self.len_splits)]

x_dtype = x.type.dtype
if isinstance(axis, Constant):
# In this case we can preserve more static shape info
static_axis = axis.data.item()
outputs = []
x_static_shape = list(x.type.shape)
for i in range(self.len_splits):
try:
static_split_size = int(get_scalar_constant_value(splits[i]))
except NotScalarConstantError:
static_split_size = None
except IndexError:
raise ValueError("Number of splits is larger than splits size")

static_out_shape = x_static_shape.copy()
static_out_shape[static_axis] = static_split_size
outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
else:
outputs = [
tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
for i in range(self.len_splits)
]

return Apply(self, inputs, outputs)

def perform(self, node, inputs, outputs):
def perform(self, node, inputs, outputs_storage):
x, axis, splits = inputs

if len(splits) != self.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != x.shape[axis]:
if splits.sum() != x.shape[axis]:
raise ValueError(
f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
)
if np.any(splits < 0):
if (splits < 0).any():
raise ValueError("Split sizes cannot be negative")

split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for i, out in enumerate(split_outs):
outputs[i][0] = out
for out_storage, out in zip(outputs_storage, split_outs, strict=False):
out_storage[0] = out

def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1]
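
A hedged sketch of the static shape propagation that `make_node` now performs (the tensor name and shapes below are assumptions chosen for illustration): when the axis is a `Constant` and every split size can be read as a scalar constant, each output type carries the known extent along the split axis and inherits the remaining dimensions from `x`; otherwise the outputs keep fully unknown shapes as before.

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 4))
a, b = pt.split(x, splits_size=[2, 3], n_splits=2, axis=0)

# With a constant axis and constant split sizes, the static shapes
# should now be propagated instead of (None, None).
print(a.type.shape)  # expected: (2, 4)
print(b.type.shape)  # expected: (3, 4)
```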
@@ -2234,10 +2252,10 @@
out_shapes.append(temp)
return out_shapes

def grad(self, inputs, g_outputs):
def L_op(self, inputs, outputs, g_outputs):
"""Join the gradients along the axis that was used to split x."""
x, axis, n = inputs
outputs = self(*inputs, return_list=True)

# If all the output gradients are disconnected, then so are the inputs
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
return [
@@ -2265,142 +2283,107 @@
return self.make_node(eval_points[0], *inputs[1:]).outputs

def c_code_cache_version(self):
return (2,)

def c_support_code(self, **kwargs):
return """
/* Return 1 if output has the correct shape. */
int split_output_shape_is_correct (
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
) {
return
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
&& memcmp(
PyArray_DIMS(output),
PyArray_DIMS(array_to_split),
axis_to_split * sizeof(npy_intp)
) == 0
&& memcmp(
PyArray_DIMS(output) + axis_to_split + 1,
PyArray_DIMS(array_to_split) + axis_to_split + 1,
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
) == 0
&& split_size == PyArray_DIM(output, axis_to_split);
}
"""
return (3,)

def c_code(self, node, name, inputs, outputs, sub):
if self.len_splits == 0:
# There are no outputs, then nothing to do.
return ""
# This would be a view Op, anyway shouldn't be triggered
raise NotImplementedError()

# outputs_pointers lists the addresses of the pointers to the outputs.
outputs_pointers = "&" + (", &".join(outputs))
x, axis, splits = inputs
fail = sub["fail"]
x_typenum = np.dtype(node.inputs[0].dtype).num
x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
axis_dtype = node.inputs[1].type.dtype_specs()[1]
splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits
len_splits = self.len_splits
ndim = node.inputs[0].type.ndim

# Most times axis is constant, inline it
# This is safe to do because the hash of the c_code includes the constant signature
if isinstance(node.inputs[1], Constant):
static_axis = int(node.inputs[1].data)
static_axis = normalize_axis_index(static_axis, ndim)
axis_def = f"{static_axis};"
axis_check = ""
else:
axis_dtype = node.inputs[1].type.dtype_specs()[1]
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
axis_check = f"""
if (axis < 0){{
axis = ndim + axis;
}}
if (axis >= ndim || axis < 0) {{
PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
{fail}
}}
"""

return f"""
int ndim = PyArray_NDIM({x});
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int ndim = {ndim};
int axis = {axis_def}
int splits_count = PyArray_DIM({splits}, 0);
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
npy_intp* split_dims = NULL;
PyObject* split_view = NULL;
npy_intp data_offset;
int i;
npy_intp sum_of_splits = 0, current_split_start = 0;
PyArrayObject** outputs[] = {{{outputs_pointers}}};
npy_intp split_dims[ndim];

/* Check inputs. */

if (splits_count != {expected_splits_count}) {{
PyErr_Format(PyExc_ValueError,
"Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
if (PyArray_NDIM({x}) != ndim) {{
PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
{fail}
}}

if (axis < 0) {{
axis += ndim;
}}
if (axis < 0 || axis >= ndim) {{
PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
if (splits_count != {len_splits}) {{
PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {len_splits});
{fail}
}}
len_along_axis = PyArray_DIM({x}, axis);

for (i = 0; i < splits_count; ++i) {{
current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
{axis_check};

for (int i = 0; i < splits_count; ++i) {{
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
if (current_split_length < 0) {{
PyErr_Format(PyExc_ValueError,
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
{fail}
}}
sum_of_splits += current_split_length;
}}
if (sum_of_splits != len_along_axis) {{
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
{fail}
}}

/* Check outputs. */

split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
if (split_dims == NULL) {{
PyErr_NoMemory();
if (sum_of_splits != PyArray_DIM({x}, axis)) {{
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis));
{fail}
}}

/* Compute split. */
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));

for (i = 0; i < splits_count; ++i) {{
PyArrayObject** output = outputs[i];
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{
Py_XDECREF(*output);
split_dims[axis] = current_split_length;
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x}));
if (outputs == NULL) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
free(split_dims);
{fail}
}}
}}
}}

/* Compute split. */
for (int i = 0; i < splits_count; ++i) {{
Py_XDECREF(*outputs[i]);

for (i = 0; i < splits_count; ++i) {{
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
// Create view of input
npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
split_dims[axis] = current_split_length;
split_view = PyArray_New(&PyArray_Type,
ndim, split_dims,
{x_typenum},
PyArray_STRIDES({x}),
PyArray_BYTES({x}) + data_offset,
{x_itemsize},
PyArray_FLAGS({x}),
NULL);
if (split_view == NULL) {{
PyArray_Descr *descr = PyArray_DESCR({x});
Py_INCREF(descr);
*outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
descr, // PyArray_NewFromDescr steals this reference
ndim, split_dims,
PyArray_STRIDES({x}),
PyArray_BYTES({x}) + data_offset,
PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA,
NULL);

if (*outputs[i] == NULL) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
free(split_dims);
{fail}
}}
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
Py_XDECREF(split_view);
free(split_dims);
{fail}
}}
Py_XDECREF(split_view);

// Set as a view of input
Py_INCREF((PyObject*){x});
PyArray_SetBaseObject(*outputs[i], (PyObject*){x});

// Update split slice pointer
current_split_start += current_split_length;
}}

free(split_dims);
"""


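Because the C implementation now wraps each split in `PyArray_NewFromDescr` over the input's own buffer (with `NPY_ARRAY_OWNDATA` cleared) and attaches the input as base via `PyArray_SetBaseObject`, the outputs are views of the input rather than copies, matching what `np.split` does in the Python implementation. A small NumPy analogue of that view semantics (an assumed example, not the generated C code):

```python
import numpy as np

x = np.zeros((6, 2))
parts = np.split(x, [2, 5], axis=0)  # cut points, as in the C/Numba paths

print(all(p.base is x for p in parts))  # True: every split is a view into x
parts[0][0, 0] = 1.0
print(x[0, 0])                          # 1.0: the write is visible through x
```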
8 changes: 5 additions & 3 deletions tests/link/jax/test_tensor_basic.py
@@ -150,12 +150,14 @@ def test_runtime_errors(self):
):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))

a_splits = ptb.split(a, splits_size=[2, 4], n_splits=3, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
# This check is triggered at compile time if splits_size has incompatible static length
splits_size = vector("splits_size", shape=(None,), dtype=int)
a_splits = ptb.split(a, splits_size=splits_size, n_splits=3, axis=0)
fn = pytensor.function([a, splits_size], a_splits, mode="JAX")
with pytest.raises(
ValueError, match="Length of splits is not equal to n_splits"
):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
fn(np.zeros((6, 4), dtype=pytensor.config.floatX), [2, 2])

a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
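A hedged sketch of the construction-time failure the updated test alludes to (the error message is taken from the new `make_node` code; that indexing a too-short constant `splits_size` fails while the graph is built is an assumption based on this diff): with constant split sizes whose static length is smaller than `n_splits`, the shape-propagation loop cannot read `splits[i]`, so the error surfaces immediately, which is why the runtime check above now needs a symbolic `splits_size` vector.

```python
import pytensor.tensor as pt

a = pt.matrix("a")
try:
    # Constant splits_size of length 2 but n_splits=3: assumed to fail
    # at graph-construction time rather than when the function runs.
    pt.split(a, splits_size=[2, 4], n_splits=3, axis=0)
except ValueError as err:
    print(err)  # e.g. "Number of splits is larger than splits size"
```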
6 changes: 1 addition & 5 deletions tests/tensor/test_basic.py
@@ -2172,11 +2172,7 @@ def test_split_view(self, linker):
res = f(x_test)
for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True):
assert np.allclose(r, expected)
if linker == "py":
assert r.base is x_test
else:
# C impl always makes a copy
assert r.base is not x_test
assert r.base is x_test


def test_TensorFromScalar():