From 2134705783108f92b08ce40a563f6214c25ab160 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 31 Mar 2025 18:54:53 +0200 Subject: [PATCH 1/4] Remove patch on Numba impl of Split --- pytensor/link/numba/dispatch/tensor_basic.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..7daa625794 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -136,10 +136,7 @@ def join(axis, *tensors): def numba_funcify_Split(op, **kwargs): @numba_basic.numba_njit def split(tensor, axis, indices): - # Work around for https://github.com/numba/numba/issues/8257 - axis = axis % tensor.ndim - axis = numba_basic.to_scalar(axis) - return np.split(tensor, np.cumsum(indices)[:-1], axis=axis) + return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item()) return split From 2a9d167f263a4ad1b2e2da3b88f85fdab49e0d52 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 31 Mar 2025 13:42:22 +0200 Subject: [PATCH 2/4] Provide static shape in output of Split --- pytensor/tensor/basic.py | 24 ++++++++++++++++++++++-- tests/link/jax/test_tensor_basic.py | 8 +++++--- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 6bcb084f4e..1db798d798 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2201,8 +2201,28 @@ def make_node(self, x, axis, splits): raise TypeError("`axis` parameter must be an integer scalar") inputs = [x, axis, splits] - out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim) - outputs = [out_type() for i in range(self.len_splits)] + + x_dtype = x.type.dtype + if isinstance(axis, Constant): + # In this case we can preserve more static shape info + static_axis = axis.data.item() + outputs = [] + x_static_shape = list(x.type.shape) + for i in range(self.len_splits): + try: + static_split_size = int(get_scalar_constant_value(splits[i])) + except NotScalarConstantError: + static_split_size = None + except IndexError: + raise ValueError("Number of splits is larger than splits size") + static_out_shape = x_static_shape.copy() + static_out_shape[static_axis] = static_split_size + outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype)) + else: + outputs = [ + tensor(shape=(None,) * x.type.ndim, dtype=x_dtype) + for i in range(self.len_splits) + ] return Apply(self, inputs, outputs) diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 5461095c70..1e1f496de1 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -150,12 +150,14 @@ def test_runtime_errors(self): ): fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) - a_splits = ptb.split(a, splits_size=[2, 4], n_splits=3, axis=0) - fn = pytensor.function([a], a_splits, mode="JAX") + # This check is triggered at compile time if splits_size has incompatible static length + splits_size = vector("splits_size", shape=(None,), dtype=int) + a_splits = ptb.split(a, splits_size=splits_size, n_splits=3, axis=0) + fn = pytensor.function([a, splits_size], a_splits, mode="JAX") with pytest.raises( ValueError, match="Length of splits is not equal to n_splits" ): - fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) + fn(np.zeros((6, 4), dtype=pytensor.config.floatX), [2, 2]) a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=0) fn = pytensor.function([a], a_splits, mode="JAX") From b27490ad027f7cbbdb01d4f546e358ae5360436f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 3 Apr 2025 13:21:15 +0200 Subject: [PATCH 3/4] Cleanup Split methods --- pytensor/tensor/basic.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 1db798d798..7fb1d66644 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2226,21 +2226,21 @@ def make_node(self, x, axis, splits): return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, outputs_storage): x, axis, splits = inputs if len(splits) != self.len_splits: raise ValueError("Length of splits is not equal to n_splits") - if np.sum(splits) != x.shape[axis]: + if splits.sum() != x.shape[axis]: raise ValueError( - f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}" + f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}" ) - if np.any(splits < 0): + if (splits < 0).any(): raise ValueError("Split sizes cannot be negative") split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis) - for i, out in enumerate(split_outs): - outputs[i][0] = out + for out_storage, out in zip(outputs_storage, split_outs, strict=False): + out_storage[0] = out def infer_shape(self, fgraph, node, in_shapes): axis = node.inputs[1] @@ -2254,10 +2254,10 @@ def infer_shape(self, fgraph, node, in_shapes): out_shapes.append(temp) return out_shapes - def grad(self, inputs, g_outputs): + def L_op(self, inputs, outputs, g_outputs): """Join the gradients along the axis that was used to split x.""" x, axis, n = inputs - outputs = self(*inputs, return_list=True) + # If all the output gradients are disconnected, then so are the inputs if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs): return [ From f1d6ba02a4b3d58dafa5a5aceaaf9fcde60917e8 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 3 Apr 2025 12:38:27 +0200 Subject: [PATCH 4/4] Make Split C-impl return a view --- pytensor/tensor/basic.py | 161 ++++++++++++++----------------------- tests/tensor/test_basic.py | 6 +- 2 files changed, 63 insertions(+), 104 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7fb1d66644..5d6c059c53 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2171,8 +2171,6 @@ class Split(COp): array([3, 4]) >>> c array([5]) - - TODO: Don't make a copy in C impl """ len_splits = None @@ -2285,75 +2283,63 @@ def R_op(self, inputs, eval_points): return self.make_node(eval_points[0], *inputs[1:]).outputs def c_code_cache_version(self): - return (2,) - - def c_support_code(self, **kwargs): - return """ - /* Return 1 if output has the correct shape. */ - int split_output_shape_is_correct ( - PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size - ) { - return - PyArray_NDIM(output) == PyArray_NDIM(array_to_split) - && memcmp( - PyArray_DIMS(output), - PyArray_DIMS(array_to_split), - axis_to_split * sizeof(npy_intp) - ) == 0 - && memcmp( - PyArray_DIMS(output) + axis_to_split + 1, - PyArray_DIMS(array_to_split) + axis_to_split + 1, - (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp) - ) == 0 - && split_size == PyArray_DIM(output, axis_to_split); - } - """ + return (3,) def c_code(self, node, name, inputs, outputs, sub): if self.len_splits == 0: - # There are no outputs, then nothing to do. - return "" + # This would be a view Op, anyway shouldn't be triggered + raise NotImplementedError() # outputs_pointers lists the addresses of the pointers to the outputs. outputs_pointers = "&" + (", &".join(outputs)) x, axis, splits = inputs fail = sub["fail"] - x_typenum = np.dtype(node.inputs[0].dtype).num - x_itemsize = np.dtype(node.inputs[0].dtype).itemsize - axis_dtype = node.inputs[1].type.dtype_specs()[1] splits_dtype = node.inputs[2].type.dtype_specs()[1] - expected_splits_count = self.len_splits + len_splits = self.len_splits + ndim = node.inputs[0].type.ndim + + # Most times axis is constant, inline it + # This is safe to do because the hash of the c_code includes the constant signature + if isinstance(node.inputs[1], Constant): + static_axis = int(node.inputs[1].data) + static_axis = normalize_axis_index(static_axis, ndim) + axis_def = f"{static_axis};" + axis_check = "" + else: + axis_dtype = node.inputs[1].type.dtype_specs()[1] + axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];" + axis_check = f""" + if (axis < 0){{ + axis = ndim + axis; + }} + if (axis >= ndim || axis < 0) {{ + PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds"); + {fail} + }} + """ return f""" - int ndim = PyArray_NDIM({x}); - int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0)); + int ndim = {ndim}; + int axis = {axis_def} int splits_count = PyArray_DIM({splits}, 0); - npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0; - npy_intp* split_dims = NULL; - PyObject* split_view = NULL; - npy_intp data_offset; - int i; + npy_intp sum_of_splits = 0, current_split_start = 0; PyArrayObject** outputs[] = {{{outputs_pointers}}}; + npy_intp split_dims[ndim]; /* Check inputs. */ - - if (splits_count != {expected_splits_count}) {{ - PyErr_Format(PyExc_ValueError, - "Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count}); + if (PyArray_NDIM({x}) != ndim) {{ + PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim"); {fail} }} - - if (axis < 0) {{ - axis += ndim; - }} - if (axis < 0 || axis >= ndim) {{ - PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim); + if (splits_count != {len_splits}) {{ + PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {len_splits}); {fail} }} - len_along_axis = PyArray_DIM({x}, axis); - for (i = 0; i < splits_count; ++i) {{ - current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i)); + {axis_check}; + + for (int i = 0; i < splits_count; ++i) {{ + int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i)); if (current_split_length < 0) {{ PyErr_Format(PyExc_ValueError, "Split: you try to take a negative number (%ld) of elements.", current_split_length); @@ -2361,66 +2347,43 @@ def c_code(self, node, name, inputs, outputs, sub): }} sum_of_splits += current_split_length; }} - if (sum_of_splits != len_along_axis) {{ - PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis); - {fail} - }} - - /* Check outputs. */ - - split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp)); - if (split_dims == NULL) {{ - PyErr_NoMemory(); + if (sum_of_splits != PyArray_DIM({x}, axis)) {{ + PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis)); {fail} }} + /* Compute split. */ memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp)); - for (i = 0; i < splits_count; ++i) {{ - PyArrayObject** output = outputs[i]; - current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i)); - if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{ - Py_XDECREF(*output); - split_dims[axis] = current_split_length; - *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x})); - if (outputs == NULL) {{ - PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output."); - free(split_dims); - {fail} - }} - }} - }} + for (int i = 0; i < splits_count; ++i) {{ + Py_XDECREF(*outputs[i]); - /* Compute split. */ - - for (i = 0; i < splits_count; ++i) {{ - current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i)); - data_offset = PyArray_STRIDE({x}, axis) * current_split_start; + // Create view of input + npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start; + int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i)); split_dims[axis] = current_split_length; - split_view = PyArray_New(&PyArray_Type, - ndim, split_dims, - {x_typenum}, - PyArray_STRIDES({x}), - PyArray_BYTES({x}) + data_offset, - {x_itemsize}, - PyArray_FLAGS({x}), - NULL); - if (split_view == NULL) {{ + PyArray_Descr *descr = PyArray_DESCR({x}); + Py_INCREF(descr); + *outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, + descr, // PyArray_NewFromDescr steals this reference + ndim, split_dims, + PyArray_STRIDES({x}), + PyArray_BYTES({x}) + data_offset, + PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA, + NULL); + + if (*outputs[i] == NULL) {{ PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split."); - free(split_dims); - {fail} - }} - if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{ - PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output."); - Py_XDECREF(split_view); - free(split_dims); {fail} }} - Py_XDECREF(split_view); + + // Set as a view of input + Py_INCREF((PyObject*){x}); + PyArray_SetBaseObject(*outputs[i], (PyObject*){x}); + + // Update split slice pointer current_split_start += current_split_length; }} - - free(split_dims); """ diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index dee0023efd..e29a47691a 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2172,11 +2172,7 @@ def test_split_view(self, linker): res = f(x_test) for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True): assert np.allclose(r, expected) - if linker == "py": - assert r.base is x_test - else: - # C impl always makes a copy - assert r.base is not x_test + assert r.base is x_test def test_TensorFromScalar():