Skip to content

Commit f1d6ba0

Browse files
committed
Make Split C-impl return a view
1 parent b27490a commit f1d6ba0

File tree

2 files changed

+63
-104
lines changed

2 files changed

+63
-104
lines changed

Diff for: pytensor/tensor/basic.py

+62-99
Original file line numberDiff line numberDiff line change
@@ -2171,8 +2171,6 @@ class Split(COp):
21712171
array([3, 4])
21722172
>>> c
21732173
array([5])
2174-
2175-
TODO: Don't make a copy in C impl
21762174
"""
21772175

21782176
len_splits = None
@@ -2285,142 +2283,107 @@ def R_op(self, inputs, eval_points):
22852283
return self.make_node(eval_points[0], *inputs[1:]).outputs
22862284

22872285
def c_code_cache_version(self):
2288-
return (2,)
2289-
2290-
def c_support_code(self, **kwargs):
2291-
return """
2292-
/* Return 1 if output has the correct shape. */
2293-
int split_output_shape_is_correct (
2294-
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2295-
) {
2296-
return
2297-
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2298-
&& memcmp(
2299-
PyArray_DIMS(output),
2300-
PyArray_DIMS(array_to_split),
2301-
axis_to_split * sizeof(npy_intp)
2302-
) == 0
2303-
&& memcmp(
2304-
PyArray_DIMS(output) + axis_to_split + 1,
2305-
PyArray_DIMS(array_to_split) + axis_to_split + 1,
2306-
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2307-
) == 0
2308-
&& split_size == PyArray_DIM(output, axis_to_split);
2309-
}
2310-
"""
2286+
return (3,)
23112287

23122288
def c_code(self, node, name, inputs, outputs, sub):
23132289
if self.len_splits == 0:
2314-
# There are no outputs, then nothing to do.
2315-
return ""
2290+
# This would be a view Op anyway, and this branch shouldn't be triggered
2291+
raise NotImplementedError()
23162292

23172293
# outputs_pointers lists the addresses of the pointers to the outputs.
23182294
outputs_pointers = "&" + (", &".join(outputs))
23192295
x, axis, splits = inputs
23202296
fail = sub["fail"]
2321-
x_typenum = np.dtype(node.inputs[0].dtype).num
2322-
x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
2323-
axis_dtype = node.inputs[1].type.dtype_specs()[1]
23242297
splits_dtype = node.inputs[2].type.dtype_specs()[1]
2325-
expected_splits_count = self.len_splits
2298+
len_splits = self.len_splits
2299+
ndim = node.inputs[0].type.ndim
2300+
2301+
# The axis is constant most of the time, so inline it
2302+
# This is safe to do because the hash of the c_code includes the constant signature
2303+
if isinstance(node.inputs[1], Constant):
2304+
static_axis = int(node.inputs[1].data)
2305+
static_axis = normalize_axis_index(static_axis, ndim)
2306+
axis_def = f"{static_axis};"
2307+
axis_check = ""
2308+
else:
2309+
axis_dtype = node.inputs[1].type.dtype_specs()[1]
2310+
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
2311+
axis_check = f"""
2312+
if (axis < 0){{
2313+
axis = ndim + axis;
2314+
}}
2315+
if (axis >= ndim || axis < 0) {{
2316+
PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2317+
{fail}
2318+
}}
2319+
"""
23262320

23272321
return f"""
2328-
int ndim = PyArray_NDIM({x});
2329-
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
2322+
int ndim = {ndim};
2323+
int axis = {axis_def}
23302324
int splits_count = PyArray_DIM({splits}, 0);
2331-
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2332-
npy_intp* split_dims = NULL;
2333-
PyObject* split_view = NULL;
2334-
npy_intp data_offset;
2335-
int i;
2325+
npy_intp sum_of_splits = 0, current_split_start = 0;
23362326
PyArrayObject** outputs[] = {{{outputs_pointers}}};
2327+
npy_intp split_dims[ndim];
23372328
23382329
/* Check inputs. */
2339-
2340-
if (splits_count != {expected_splits_count}) {{
2341-
PyErr_Format(PyExc_ValueError,
2342-
"Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
2330+
if (PyArray_NDIM({x}) != ndim) {{
2331+
PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
23432332
{fail}
23442333
}}
2345-
2346-
if (axis < 0) {{
2347-
axis += ndim;
2348-
}}
2349-
if (axis < 0 || axis >= ndim) {{
2350-
PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2334+
if (splits_count != {len_splits}) {{
2335+
PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {len_splits});
23512336
{fail}
23522337
}}
2353-
len_along_axis = PyArray_DIM({x}, axis);
23542338
2355-
for (i = 0; i < splits_count; ++i) {{
2356-
current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
2339+
{axis_check};
2340+
2341+
for (int i = 0; i < splits_count; ++i) {{
2342+
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
23572343
if (current_split_length < 0) {{
23582344
PyErr_Format(PyExc_ValueError,
23592345
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
23602346
{fail}
23612347
}}
23622348
sum_of_splits += current_split_length;
23632349
}}
2364-
if (sum_of_splits != len_along_axis) {{
2365-
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2366-
{fail}
2367-
}}
2368-
2369-
/* Check outputs. */
2370-
2371-
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2372-
if (split_dims == NULL) {{
2373-
PyErr_NoMemory();
2350+
if (sum_of_splits != PyArray_DIM({x}, axis)) {{
2351+
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis));
23742352
{fail}
23752353
}}
23762354
2355+
/* Compute split. */
23772356
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));
23782357
2379-
for (i = 0; i < splits_count; ++i) {{
2380-
PyArrayObject** output = outputs[i];
2381-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2382-
if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{
2383-
Py_XDECREF(*output);
2384-
split_dims[axis] = current_split_length;
2385-
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x}));
2386-
if (outputs == NULL) {{
2387-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2388-
free(split_dims);
2389-
{fail}
2390-
}}
2391-
}}
2392-
}}
2358+
for (int i = 0; i < splits_count; ++i) {{
2359+
Py_XDECREF(*outputs[i]);
23932360
2394-
/* Compute split. */
2395-
2396-
for (i = 0; i < splits_count; ++i) {{
2397-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2398-
data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2361+
// Create view of input
2362+
npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2363+
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
23992364
split_dims[axis] = current_split_length;
2400-
split_view = PyArray_New(&PyArray_Type,
2401-
ndim, split_dims,
2402-
{x_typenum},
2403-
PyArray_STRIDES({x}),
2404-
PyArray_BYTES({x}) + data_offset,
2405-
{x_itemsize},
2406-
PyArray_FLAGS({x}),
2407-
NULL);
2408-
if (split_view == NULL) {{
2365+
PyArray_Descr *descr = PyArray_DESCR({x});
2366+
Py_INCREF(descr);
2367+
*outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2368+
descr, // PyArray_NewFromDescr steals this reference
2369+
ndim, split_dims,
2370+
PyArray_STRIDES({x}),
2371+
PyArray_BYTES({x}) + data_offset,
2372+
PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA,
2373+
NULL);
2374+
2375+
if (*outputs[i] == NULL) {{
24092376
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
2410-
free(split_dims);
2411-
{fail}
2412-
}}
2413-
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2414-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2415-
Py_XDECREF(split_view);
2416-
free(split_dims);
24172377
{fail}
24182378
}}
2419-
Py_XDECREF(split_view);
2379+
2380+
// Set as a view of input
2381+
Py_INCREF((PyObject*){x});
2382+
PyArray_SetBaseObject(*outputs[i], (PyObject*){x});
2383+
2384+
// Update split slice pointer
24202385
current_split_start += current_split_length;
24212386
}}
2422-
2423-
free(split_dims);
24242387
"""
24252388

24262389

Diff for: tests/tensor/test_basic.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -2172,11 +2172,7 @@ def test_split_view(self, linker):
21722172
res = f(x_test)
21732173
for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True):
21742174
assert np.allclose(r, expected)
2175-
if linker == "py":
2176-
assert r.base is x_test
2177-
else:
2178-
# C impl always makes a copy
2179-
assert r.base is not x_test
2175+
assert r.base is x_test
21802176

21812177

21822178
def test_TensorFromScalar():

0 commit comments

Comments
 (0)