@@ -2171,8 +2171,6 @@ class Split(COp):
2171
2171
array([3, 4])
2172
2172
>>> c
2173
2173
array([5])
2174
-
2175
- TODO: Don't make a copy in C impl
2176
2174
"""
2177
2175
2178
2176
len_splits = None
@@ -2285,142 +2283,107 @@ def R_op(self, inputs, eval_points):
2285
2283
return self .make_node (eval_points [0 ], * inputs [1 :]).outputs
2286
2284
2287
2285
def c_code_cache_version (self ):
2288
- return (2 ,)
2289
-
2290
- def c_support_code (self , ** kwargs ):
2291
- return """
2292
- /* Return 1 if output has the correct shape. */
2293
- int split_output_shape_is_correct (
2294
- PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2295
- ) {
2296
- return
2297
- PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2298
- && memcmp(
2299
- PyArray_DIMS(output),
2300
- PyArray_DIMS(array_to_split),
2301
- axis_to_split * sizeof(npy_intp)
2302
- ) == 0
2303
- && memcmp(
2304
- PyArray_DIMS(output) + axis_to_split + 1,
2305
- PyArray_DIMS(array_to_split) + axis_to_split + 1,
2306
- (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2307
- ) == 0
2308
- && split_size == PyArray_DIM(output, axis_to_split);
2309
- }
2310
- """
2286
+ return (3 ,)
2311
2287
2312
2288
def c_code (self , node , name , inputs , outputs , sub ):
2313
2289
if self .len_splits == 0 :
2314
- # There are no outputs, then nothing to do.
2315
- return ""
2290
+ # This would be a view Op, anyway shouldn't be triggered
2291
+ raise NotImplementedError ()
2316
2292
2317
2293
# outputs_pointers lists the addresses of the pointers to the outputs.
2318
2294
outputs_pointers = "&" + (", &" .join (outputs ))
2319
2295
x , axis , splits = inputs
2320
2296
fail = sub ["fail" ]
2321
- x_typenum = np .dtype (node .inputs [0 ].dtype ).num
2322
- x_itemsize = np .dtype (node .inputs [0 ].dtype ).itemsize
2323
- axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2324
2297
splits_dtype = node .inputs [2 ].type .dtype_specs ()[1 ]
2325
- expected_splits_count = self .len_splits
2298
+ len_splits = self .len_splits
2299
+ ndim = node .inputs [0 ].type .ndim
2300
+
2301
+ # Most times axis is constant, inline it
2302
+ # This is safe to do because the hash of the c_code includes the constant signature
2303
+ if isinstance (node .inputs [1 ], Constant ):
2304
+ static_axis = int (node .inputs [1 ].data )
2305
+ static_axis = normalize_axis_index (static_axis , ndim )
2306
+ axis_def = f"{ static_axis } ;"
2307
+ axis_check = ""
2308
+ else :
2309
+ axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2310
+ axis_def = f"(({ axis_dtype } *)PyArray_DATA({ axis } ))[0];"
2311
+ axis_check = f"""
2312
+ if (axis < 0){{
2313
+ axis = ndim + axis;
2314
+ }}
2315
+ if (axis >= ndim || axis < 0) {{
2316
+ PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2317
+ { fail }
2318
+ }}
2319
+ """
2326
2320
2327
2321
return f"""
2328
- int ndim = PyArray_NDIM( { x } ) ;
2329
- int axis = (int)(*( { axis_dtype } *)PyArray_GETPTR1( { axis } , 0));
2322
+ int ndim = { ndim } ;
2323
+ int axis = { axis_def }
2330
2324
int splits_count = PyArray_DIM({ splits } , 0);
2331
- npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2332
- npy_intp* split_dims = NULL;
2333
- PyObject* split_view = NULL;
2334
- npy_intp data_offset;
2335
- int i;
2325
+ npy_intp sum_of_splits = 0, current_split_start = 0;
2336
2326
PyArrayObject** outputs[] = {{{ outputs_pointers } }};
2327
+ npy_intp split_dims[ndim];
2337
2328
2338
2329
/* Check inputs. */
2339
-
2340
- if (splits_count != { expected_splits_count } ) {{
2341
- PyErr_Format(PyExc_ValueError,
2342
- "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
2330
+ if (PyArray_NDIM({ x } ) != ndim) {{
2331
+ PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
2343
2332
{ fail }
2344
2333
}}
2345
-
2346
- if (axis < 0) {{
2347
- axis += ndim;
2348
- }}
2349
- if (axis < 0 || axis >= ndim) {{
2350
- PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2334
+ if (splits_count != { len_splits } ) {{
2335
+ PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, { len_splits } );
2351
2336
{ fail }
2352
2337
}}
2353
- len_along_axis = PyArray_DIM({ x } , axis);
2354
2338
2355
- for (i = 0; i < splits_count; ++i) {{
2356
- current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2339
+ { axis_check } ;
2340
+
2341
+ for (int i = 0; i < splits_count; ++i) {{
2342
+ int current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2357
2343
if (current_split_length < 0) {{
2358
2344
PyErr_Format(PyExc_ValueError,
2359
2345
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
2360
2346
{ fail }
2361
2347
}}
2362
2348
sum_of_splits += current_split_length;
2363
2349
}}
2364
- if (sum_of_splits != len_along_axis) {{
2365
- PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2366
- { fail }
2367
- }}
2368
-
2369
- /* Check outputs. */
2370
-
2371
- split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2372
- if (split_dims == NULL) {{
2373
- PyErr_NoMemory();
2350
+ if (sum_of_splits != PyArray_DIM({ x } , axis)) {{
2351
+ PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({ x } , axis));
2374
2352
{ fail }
2375
2353
}}
2376
2354
2355
+ /* Compute split. */
2377
2356
memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
2378
2357
2379
- for (i = 0; i < splits_count; ++i) {{
2380
- PyArrayObject** output = outputs[i];
2381
- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2382
- if (*output == NULL || !split_output_shape_is_correct(*output, { x } , axis, current_split_length)) {{
2383
- Py_XDECREF(*output);
2384
- split_dims[axis] = current_split_length;
2385
- *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, { x_typenum } , PyArray_IS_F_CONTIGUOUS({ x } ));
2386
- if (outputs == NULL) {{
2387
- PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2388
- free(split_dims);
2389
- { fail }
2390
- }}
2391
- }}
2392
- }}
2358
+ for (int i = 0; i < splits_count; ++i) {{
2359
+ Py_XDECREF(*outputs[i]);
2393
2360
2394
- /* Compute split. */
2395
-
2396
- for (i = 0; i < splits_count; ++i) {{
2397
- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2398
- data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2361
+ // Create view of input
2362
+ npy_intp data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2363
+ int current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2399
2364
split_dims[axis] = current_split_length;
2400
- split_view = PyArray_New(&PyArray_Type,
2401
- ndim, split_dims,
2402
- { x_typenum } ,
2403
- PyArray_STRIDES({ x } ),
2404
- PyArray_BYTES({ x } ) + data_offset,
2405
- { x_itemsize } ,
2406
- PyArray_FLAGS({ x } ),
2407
- NULL);
2408
- if (split_view == NULL) {{
2365
+ PyArray_Descr *descr = PyArray_DESCR({ x } );
2366
+ Py_INCREF(descr);
2367
+ *outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2368
+ descr, // PyArray_NewFromDescr steals this reference
2369
+ ndim, split_dims,
2370
+ PyArray_STRIDES({ x } ),
2371
+ PyArray_BYTES({ x } ) + data_offset,
2372
+ PyArray_FLAGS({ x } ) & ~NPY_ARRAY_OWNDATA,
2373
+ NULL);
2374
+
2375
+ if (*outputs[i] == NULL) {{
2409
2376
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
2410
- free(split_dims);
2411
- { fail }
2412
- }}
2413
- if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2414
- PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2415
- Py_XDECREF(split_view);
2416
- free(split_dims);
2417
2377
{ fail }
2418
2378
}}
2419
- Py_XDECREF(split_view);
2379
+
2380
+ // Set as a view of input
2381
+ Py_INCREF((PyObject*){ x } );
2382
+ PyArray_SetBaseObject(*outputs[i], (PyObject*){ x } );
2383
+
2384
+ // Update split slice pointer
2420
2385
current_split_start += current_split_length;
2421
2386
}}
2422
-
2423
- free(split_dims);
2424
2387
"""
2425
2388
2426
2389
0 commit comments