3
3
import warnings
4
4
from collections .abc import Callable , Iterable , Sequence
5
5
from itertools import chain , groupby
6
- from textwrap import dedent
7
6
from typing import cast , overload
8
7
9
8
import numpy as np
19
18
from pytensor .graph .utils import MethodNotDefined
20
19
from pytensor .link .c .op import COp
21
20
from pytensor .link .c .params_type import ParamsType
22
- from pytensor .npy_2_compat import npy_2_compat_header , numpy_version , using_numpy_2
21
+ from pytensor .npy_2_compat import numpy_version , using_numpy_2
23
22
from pytensor .printing import Printer , pprint , set_precedence
24
23
from pytensor .scalar .basic import ScalarConstant , ScalarVariable
25
24
from pytensor .tensor import (
@@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_):
2130
2129
else :
2131
2130
o = None
2132
2131
2133
- # If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
2134
- # int64 on 64-bit machines), numpy may raise the following error:
2135
- # TypeError: array cannot be safely cast to required type.
2136
- # We need to check if values in i can fit in numpy.intp, because
2137
- # if they don't, that should be an error (no array can have that
2138
- # many elements on a 32-bit arch).
2139
- if i .dtype != np .intp :
2140
- i_ = np .asarray (i , dtype = np .intp )
2141
- if not np .can_cast (i .dtype , np .intp ):
2142
- # Check if there was actually an incorrect conversion
2143
- if np .any (i != i_ ):
2144
- raise IndexError (
2145
- "index contains values that are bigger "
2146
- "than the maximum array size on this system." ,
2147
- i ,
2148
- )
2149
- i = i_
2150
-
2151
2132
out [0 ] = x .take (i , axis = 0 , out = o )
2152
2133
2153
2134
def connection_pattern (self , node ):
@@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes):
2187
2168
x , ilist = ishapes
2188
2169
return [ilist + x [1 :]]
2189
2170
2190
- def c_support_code (self , ** kwargs ):
2191
- # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
2192
- # which is not defined. It should be NPY_MIN_LONG instead in that case.
2193
- return npy_2_compat_header () + dedent (
2194
- """\
2195
- #ifndef MIN_LONG
2196
- #define MIN_LONG NPY_MIN_LONG
2197
- #endif"""
2198
- )
2199
-
2200
2171
def c_code (self , node , name , input_names , output_names , sub ):
2201
2172
if self .__class__ is not AdvancedSubtensor1 :
2202
2173
raise MethodNotDefined (
@@ -2207,69 +2178,24 @@ def c_code(self, node, name, input_names, output_names, sub):
2207
2178
output_name = output_names [0 ]
2208
2179
fail = sub ["fail" ]
2209
2180
return f"""
2210
- PyArrayObject *indices;
2211
- int i_type = PyArray_TYPE({ i_name } );
2212
- if (i_type != NPY_INTP) {{
2213
- // Cast { i_name } to NPY_INTP (expected by PyArray_TakeFrom),
2214
- // if all values fit.
2215
- if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
2216
- PyArray_SIZE({ i_name } ) > 0) {{
2217
- npy_int64 min_val, max_val;
2218
- PyObject* py_min_val = PyArray_Min({ i_name } , NPY_RAVEL_AXIS,
2219
- NULL);
2220
- if (py_min_val == NULL) {{
2221
- { fail } ;
2222
- }}
2223
- min_val = PyLong_AsLongLong(py_min_val);
2224
- Py_DECREF(py_min_val);
2225
- if (min_val == -1 && PyErr_Occurred()) {{
2226
- { fail } ;
2227
- }}
2228
- PyObject* py_max_val = PyArray_Max({ i_name } , NPY_RAVEL_AXIS,
2229
- NULL);
2230
- if (py_max_val == NULL) {{
2231
- { fail } ;
2232
- }}
2233
- max_val = PyLong_AsLongLong(py_max_val);
2234
- Py_DECREF(py_max_val);
2235
- if (max_val == -1 && PyErr_Occurred()) {{
2236
- { fail } ;
2237
- }}
2238
- if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
2239
- PyErr_SetString(PyExc_IndexError,
2240
- "Index contains values "
2241
- "that are bigger than the maximum array "
2242
- "size on this system.");
2243
- { fail } ;
2244
- }}
2245
- }}
2246
- indices = (PyArrayObject*) PyArray_Cast({ i_name } , NPY_INTP);
2247
- if (indices == NULL) {{
2248
- { fail } ;
2249
- }}
2250
- }}
2251
- else {{
2252
- indices = { i_name } ;
2253
- Py_INCREF(indices);
2254
- }}
2255
2181
if ({ output_name } != NULL) {{
2256
2182
npy_intp nd, i, *shape;
2257
- nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM(indices ) - 1;
2183
+ nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM({ i_name } ) - 1;
2258
2184
if (PyArray_NDIM({ output_name } ) != nd) {{
2259
2185
Py_CLEAR({ output_name } );
2260
2186
}}
2261
2187
else {{
2262
2188
shape = PyArray_DIMS({ output_name } );
2263
- for (i = 0; i < PyArray_NDIM(indices ); i++) {{
2264
- if (shape[i] != PyArray_DIMS(indices )[i]) {{
2189
+ for (i = 0; i < PyArray_NDIM({ i_name } ); i++) {{
2190
+ if (shape[i] != PyArray_DIMS({ i_name } )[i]) {{
2265
2191
Py_CLEAR({ output_name } );
2266
2192
break;
2267
2193
}}
2268
2194
}}
2269
2195
if ({ output_name } != NULL) {{
2270
2196
for (; i < nd; i++) {{
2271
2197
if (shape[i] != PyArray_DIMS({ a_name } )[
2272
- i-PyArray_NDIM(indices )+1]) {{
2198
+ i-PyArray_NDIM({ i_name } )+1]) {{
2273
2199
Py_CLEAR({ output_name } );
2274
2200
break;
2275
2201
}}
@@ -2278,13 +2204,12 @@ def c_code(self, node, name, input_names, output_names, sub):
2278
2204
}}
2279
2205
}}
2280
2206
{ output_name } = (PyArrayObject*)PyArray_TakeFrom(
2281
- { a_name } , (PyObject*)indices, 0, { output_name } , NPY_RAISE);
2282
- Py_DECREF(indices);
2207
+ { a_name } , (PyObject*){ i_name } , 0, { output_name } , NPY_RAISE);
2283
2208
if ({ output_name } == NULL) { fail } ;
2284
2209
"""
2285
2210
2286
2211
def c_code_cache_version (self ):
2287
- return (0 , 1 , 2 , 3 )
2212
+ return (4 , )
2288
2213
2289
2214
2290
2215
advanced_subtensor1 = AdvancedSubtensor1 ()
0 commit comments