Skip to content

Commit d9990ad

Browse files
committed
Make Split C-impl return a view
1 parent b27490a commit d9990ad

File tree

1 file changed

+63
-99
lines changed

1 file changed

+63
-99
lines changed

pytensor/tensor/basic.py

Lines changed: 63 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,8 +2171,6 @@ class Split(COp):
21712171
array([3, 4])
21722172
>>> c
21732173
array([5])
2174-
2175-
TODO: Don't make a copy in C impl
21762174
"""
21772175

21782176
len_splits = None
@@ -2285,142 +2283,108 @@ def R_op(self, inputs, eval_points):
22852283
return self.make_node(eval_points[0], *inputs[1:]).outputs
22862284

22872285
def c_code_cache_version(self):
2288-
return (2,)
2289-
2290-
def c_support_code(self, **kwargs):
2291-
return """
2292-
/* Return 1 if output has the correct shape. */
2293-
int split_output_shape_is_correct (
2294-
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2295-
) {
2296-
return
2297-
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2298-
&& memcmp(
2299-
PyArray_DIMS(output),
2300-
PyArray_DIMS(array_to_split),
2301-
axis_to_split * sizeof(npy_intp)
2302-
) == 0
2303-
&& memcmp(
2304-
PyArray_DIMS(output) + axis_to_split + 1,
2305-
PyArray_DIMS(array_to_split) + axis_to_split + 1,
2306-
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2307-
) == 0
2308-
&& split_size == PyArray_DIM(output, axis_to_split);
2309-
}
2310-
"""
2286+
return (3,)
23112287

23122288
def c_code(self, node, name, inputs, outputs, sub):
23132289
if self.len_splits == 0:
2314-
# There are no outputs, then nothing to do.
2315-
return ""
2290+
# This would be a view Op, anyway shouldn't be triggered
2291+
raise NotImplementedError()
23162292

23172293
# outputs_pointers lists the addresses of the pointers to the outputs.
23182294
outputs_pointers = "&" + (", &".join(outputs))
23192295
x, axis, splits = inputs
23202296
fail = sub["fail"]
2321-
x_typenum = np.dtype(node.inputs[0].dtype).num
2322-
x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
2323-
axis_dtype = node.inputs[1].type.dtype_specs()[1]
23242297
splits_dtype = node.inputs[2].type.dtype_specs()[1]
2325-
expected_splits_count = self.len_splits
2298+
len_splits = self.len_splits
2299+
ndim = node.inputs[0].type.ndim
2300+
2301+
# Most times axis is constant, inline it
2302+
# This is safe to do because the hash of the c_code includes the constant signature
2303+
if isinstance(node.inputs[1], Constant):
2304+
static_axis = int(node.inputs[1].data)
2305+
static_axis = normalize_axis_index(static_axis, ndim)
2306+
axis_def = f"{static_axis};"
2307+
axis_check = ""
2308+
else:
2309+
axis_dtype = node.inputs[1].type.dtype_specs()[1]
2310+
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
2311+
axis_check = f"""
2312+
if (axis < 0){{
2313+
axis = ndim + axis;
2314+
}}
2315+
if (axis >= ndim || axis < 0) {{
2316+
PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2317+
{fail}
2318+
}}
2319+
"""
23262320

23272321
return f"""
2328-
int ndim = PyArray_NDIM({x});
2329-
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
2322+
int ndim = {ndim};
2323+
int axis = {axis_def}
23302324
int splits_count = PyArray_DIM({splits}, 0);
2331-
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2332-
npy_intp* split_dims = NULL;
2333-
PyObject* split_view = NULL;
2334-
npy_intp data_offset;
2335-
int i;
2325+
npy_intp sum_of_splits = 0, current_split_start = 0;
23362326
PyArrayObject** outputs[] = {{{outputs_pointers}}};
2327+
npy_intp split_dims[ndim];
2328+
PyObject* split_view = NULL;
23372329
23382330
/* Check inputs. */
2339-
2340-
if (splits_count != {expected_splits_count}) {{
2341-
PyErr_Format(PyExc_ValueError,
2342-
"Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
2331+
if (PyArray_NDIM({x}) != ndim) {{
2332+
PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
23432333
{fail}
23442334
}}
2345-
2346-
if (axis < 0) {{
2347-
axis += ndim;
2348-
}}
2349-
if (axis < 0 || axis >= ndim) {{
2350-
PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2335+
if (splits_count != {len_splits}) {{
2336+
PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {len_splits});
23512337
{fail}
23522338
}}
2353-
len_along_axis = PyArray_DIM({x}, axis);
23542339
2355-
for (i = 0; i < splits_count; ++i) {{
2356-
current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
2340+
{axis_check};
2341+
2342+
for (int i = 0; i < splits_count; ++i) {{
2343+
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
23572344
if (current_split_length < 0) {{
23582345
PyErr_Format(PyExc_ValueError,
23592346
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
23602347
{fail}
23612348
}}
23622349
sum_of_splits += current_split_length;
23632350
}}
2364-
if (sum_of_splits != len_along_axis) {{
2365-
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2366-
{fail}
2367-
}}
2368-
2369-
/* Check outputs. */
2370-
2371-
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2372-
if (split_dims == NULL) {{
2373-
PyErr_NoMemory();
2351+
if (sum_of_splits != PyArray_DIM({x}, axis)) {{
2352+
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis));
23742353
{fail}
23752354
}}
23762355
2356+
/* Compute split. */
23772357
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));
23782358
2379-
for (i = 0; i < splits_count; ++i) {{
2380-
PyArrayObject** output = outputs[i];
2381-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2382-
if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{
2383-
Py_XDECREF(*output);
2384-
split_dims[axis] = current_split_length;
2385-
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x}));
2386-
if (outputs == NULL) {{
2387-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2388-
free(split_dims);
2389-
{fail}
2390-
}}
2391-
}}
2392-
}}
2393-
2394-
/* Compute split. */
2359+
for (int i = 0; i < splits_count; ++i) {{
2360+
Py_XDECREF(*outputs[i]);
23952361
2396-
for (i = 0; i < splits_count; ++i) {{
2397-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2398-
data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2362+
// Create view of input
2363+
npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2364+
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
23992365
split_dims[axis] = current_split_length;
2400-
split_view = PyArray_New(&PyArray_Type,
2401-
ndim, split_dims,
2402-
{x_typenum},
2403-
PyArray_STRIDES({x}),
2404-
PyArray_BYTES({x}) + data_offset,
2405-
{x_itemsize},
2406-
PyArray_FLAGS({x}),
2407-
NULL);
2408-
if (split_view == NULL) {{
2366+
PyArray_Descr *descr = PyArray_DESCR({x});
2367+
Py_INCREF(descr);
2368+
*outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2369+
descr, // PyArray_NewFromDescr steals this reference
2370+
ndim, split_dims,
2371+
PyArray_STRIDES({x}),
2372+
PyArray_BYTES({x}) + data_offset,
2373+
PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA,
2374+
NULL);
2375+
2376+
if (*outputs[i] == NULL) {{
24092377
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
2410-
free(split_dims);
2411-
{fail}
2412-
}}
2413-
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2414-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2415-
Py_XDECREF(split_view);
2416-
free(split_dims);
24172378
{fail}
24182379
}}
2419-
Py_XDECREF(split_view);
2380+
2381+
// Set as a view of input
2382+
Py_INCREF((PyObject*){x});
2383+
PyArray_SetBaseObject(*outputs[i], (PyObject*){x});
2384+
2385+
// Update split slice pointer
24202386
current_split_start += current_split_length;
24212387
}}
2422-
2423-
free(split_dims);
24242388
"""
24252389

24262390

0 commit comments

Comments
 (0)