Skip to content

Commit 13f4443

Browse files
author
Diptorup Deb
authored
Merge pull request #1585 from IntelPython/fix/kernel_arg_type
Start kernel_arg_type enums from 0
2 parents 545dff2 + 5ef035e commit 13f4443

File tree

3 files changed

+179
-19
lines changed

3 files changed

+179
-19
lines changed

dpctl/_sycl_queue.pyx

+149
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,155 @@ __all__ = [
8888
_logger = logging.getLogger(__name__)
8989

9090

91+
cdef class kernel_arg_type_attribute:
92+
cdef str parent_name
93+
cdef str attr_name
94+
cdef int attr_value
95+
96+
def __cinit__(self, str parent, str name, int value):
97+
self.parent_name = parent
98+
self.attr_name = name
99+
self.attr_value = value
100+
101+
def __repr__(self):
102+
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"
103+
104+
def __str__(self):
105+
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"
106+
107+
@property
108+
def name(self):
109+
return self.attr_name
110+
111+
@property
112+
def value(self):
113+
return self.attr_value
114+
115+
116+
cdef class _kernel_arg_type:
117+
"""
118+
An enumeration of supported kernel argument types in
119+
:func:`dpctl.SyclQueue.submit`
120+
"""
121+
cdef str _name
122+
123+
def __cinit__(self):
124+
self._name = "kernel_arg_type"
125+
126+
127+
@property
128+
def __name__(self):
129+
return self._name
130+
131+
def __repr__(self):
132+
return "<enum 'kernel_arg_type'>"
133+
134+
def __str__(self):
135+
return "<enum 'kernel_arg_type'>"
136+
137+
@property
138+
def dpctl_int8(self):
139+
cdef str p_name = "dpctl_int8"
140+
return kernel_arg_type_attribute(
141+
self._name,
142+
p_name,
143+
_arg_data_type._INT8_T
144+
)
145+
146+
@property
147+
def dpctl_uint8(self):
148+
cdef str p_name = "dpctl_uint8"
149+
return kernel_arg_type_attribute(
150+
self._name,
151+
p_name,
152+
_arg_data_type._UINT8_T
153+
)
154+
155+
@property
156+
def dpctl_int16(self):
157+
cdef str p_name = "dpctl_int16"
158+
return kernel_arg_type_attribute(
159+
self._name,
160+
p_name,
161+
_arg_data_type._INT16_T
162+
)
163+
164+
@property
165+
def dpctl_uint16(self):
166+
cdef str p_name = "dpctl_uint16"
167+
return kernel_arg_type_attribute(
168+
self._name,
169+
p_name,
170+
_arg_data_type._UINT16_T
171+
)
172+
173+
@property
174+
def dpctl_int32(self):
175+
cdef str p_name = "dpctl_int32"
176+
return kernel_arg_type_attribute(
177+
self._name,
178+
p_name,
179+
_arg_data_type._INT32_T
180+
)
181+
182+
@property
183+
def dpctl_uint32(self):
184+
cdef str p_name = "dpctl_uint32"
185+
return kernel_arg_type_attribute(
186+
self._name,
187+
p_name,
188+
_arg_data_type._UINT32_T
189+
)
190+
191+
@property
192+
def dpctl_int64(self):
193+
cdef str p_name = "dpctl_int64"
194+
return kernel_arg_type_attribute(
195+
self._name,
196+
p_name,
197+
_arg_data_type._INT64_T
198+
)
199+
200+
@property
201+
def dpctl_uint64(self):
202+
cdef str p_name = "dpctl_uint64"
203+
return kernel_arg_type_attribute(
204+
self._name,
205+
p_name,
206+
_arg_data_type._UINT64_T
207+
)
208+
209+
@property
210+
def dpctl_float32(self):
211+
cdef str p_name = "dpctl_float32"
212+
return kernel_arg_type_attribute(
213+
self._name,
214+
p_name,
215+
_arg_data_type._FLOAT
216+
)
217+
218+
@property
219+
def dpctl_float64(self):
220+
cdef str p_name = "dpctl_float64"
221+
return kernel_arg_type_attribute(
222+
self._name,
223+
p_name,
224+
_arg_data_type._DOUBLE
225+
)
226+
227+
@property
228+
def dpctl_void_ptr(self):
229+
cdef str p_name = "dpctl_void_ptr"
230+
return kernel_arg_type_attribute(
231+
self._name,
232+
p_name,
233+
_arg_data_type._VOID_PTR
234+
)
235+
236+
237+
kernel_arg_type = _kernel_arg_type()
238+
239+
91240
cdef class SyclKernelSubmitError(Exception):
92241
"""
93242
A SyclKernelSubmitError exception is raised when the provided

dpctl/enum_types.py

-19
Original file line numberDiff line numberDiff line change
@@ -113,22 +113,3 @@ class global_mem_cache_type(Enum):
113113
none = auto()
114114
read_only = auto()
115115
read_write = auto()
116-
117-
118-
class kernel_arg_type(Enum):
119-
"""
120-
An enumeration of supported kernel argument types in
121-
:func:`dpctl.SyclQueue.submit`
122-
"""
123-
124-
dpctl_int8 = auto()
125-
dpctl_uint8 = auto()
126-
dpctl_int16 = auto()
127-
dpctl_uint16 = auto()
128-
dpctl_int32 = auto()
129-
dpctl_uint32 = auto()
130-
dpctl_int64 = auto()
131-
dpctl_uint64 = auto()
132-
dpctl_float32 = auto()
133-
dpctl_float64 = auto()
134-
dpctl_void_ptr = auto()

dpctl/tests/test_sycl_kernel_submit.py

+30
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import dpctl.memory as dpctl_mem
2727
import dpctl.program as dpctl_prog
2828
import dpctl.tensor as dpt
29+
from dpctl._sycl_queue import kernel_arg_type
2930

3031

3132
@pytest.mark.parametrize(
@@ -244,3 +245,32 @@ def test_submit_async():
244245
Xref[2, i] = min(Xref[0, i], Xref[1, i])
245246

246247
assert np.array_equal(Xnp[:, :n], Xref[:, :n])
248+
249+
250+
def _check_kernel_arg_type_instance(kati):
251+
assert isinstance(kati.name, str)
252+
assert isinstance(kati.value, int)
253+
assert isinstance(repr(kati), str)
254+
assert isinstance(str(kati), str)
255+
256+
257+
def test_kernel_arg_type():
258+
"""
259+
Check that enum values for kernel_arg_type start at 0,
260+
as numba_dpex expects. The next enumerated type must
261+
have next value.
262+
"""
263+
assert isinstance(kernel_arg_type.__name__, str)
264+
assert isinstance(repr(kernel_arg_type), str)
265+
assert isinstance(str(kernel_arg_type), str)
266+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int8)
267+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint8)
268+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int16)
269+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint16)
270+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int32)
271+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint32)
272+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int64)
273+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint64)
274+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float32)
275+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
276+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)

0 commit comments

Comments
 (0)