Skip to content

Commit 726b359

Browse files
Remove kernel_arg_type from enum_types, implement in _sycl_queue
The new implementation uses values from C enum, and hence the consistency between values in Python and values in C are assured.
1 parent fc910c4 commit 726b359

File tree

3 files changed

+150
-20
lines changed

3 files changed

+150
-20
lines changed

dpctl/_sycl_queue.pyx

+149
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,155 @@ __all__ = [
8888
_logger = logging.getLogger(__name__)
8989

9090

91+
cdef class kernel_arg_type_attribute:
92+
cdef str parent_name
93+
cdef str attr_name
94+
cdef int attr_value
95+
96+
def __cinit__(self, str parent, str name, int value):
97+
self.parent_name = parent
98+
self.attr_name = name
99+
self.attr_value = value
100+
101+
def __repr__(self):
102+
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"
103+
104+
def __str__(self):
105+
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"
106+
107+
@property
108+
def name(self):
109+
return self.attr_name
110+
111+
@property
112+
def value(self):
113+
return self.attr_value
114+
115+
116+
cdef class _kernel_arg_type:
117+
"""
118+
An enumeration of supported kernel argument types in
119+
:func:`dpctl.SyclQueue.submit`
120+
"""
121+
cdef str _name
122+
123+
def __cinit__(self):
124+
self._name = "kernel_arg_type"
125+
126+
127+
@property
128+
def __name__(self):
129+
return self._name
130+
131+
def __repr__(self):
132+
return "<enum 'kernel_arg_type'>"
133+
134+
def __str__(self):
135+
return "<enum 'kernel_arg_type'>"
136+
137+
@property
138+
def dpctl_int8(self):
139+
cdef str p_name = "dpctl_int8"
140+
return kernel_arg_type_attribute(
141+
self._name,
142+
p_name,
143+
_arg_data_type._INT8_T
144+
)
145+
146+
@property
147+
def dpctl_uint8(self):
148+
cdef str p_name = "dpctl_uint8"
149+
return kernel_arg_type_attribute(
150+
self._name,
151+
p_name,
152+
_arg_data_type._UINT8_T
153+
)
154+
155+
@property
156+
def dpctl_int16(self):
157+
cdef str p_name = "dpctl_int16"
158+
return kernel_arg_type_attribute(
159+
self._name,
160+
p_name,
161+
_arg_data_type._INT16_T
162+
)
163+
164+
@property
165+
def dpctl_uint16(self):
166+
cdef str p_name = "dpctl_uint16"
167+
return kernel_arg_type_attribute(
168+
self._name,
169+
p_name,
170+
_arg_data_type._UINT16_T
171+
)
172+
173+
@property
174+
def dpctl_int32(self):
175+
cdef str p_name = "dpctl_int32"
176+
return kernel_arg_type_attribute(
177+
self._name,
178+
p_name,
179+
_arg_data_type._INT32_T
180+
)
181+
182+
@property
183+
def dpctl_uint32(self):
184+
cdef str p_name = "dpctl_uint32"
185+
return kernel_arg_type_attribute(
186+
self._name,
187+
p_name,
188+
_arg_data_type._UINT32_T
189+
)
190+
191+
@property
192+
def dpctl_int64(self):
193+
cdef str p_name = "dpctl_int64"
194+
return kernel_arg_type_attribute(
195+
self._name,
196+
p_name,
197+
_arg_data_type._INT64_T
198+
)
199+
200+
@property
201+
def dpctl_uint64(self):
202+
cdef str p_name = "dpctl_uint64"
203+
return kernel_arg_type_attribute(
204+
self._name,
205+
p_name,
206+
_arg_data_type._UINT64_T
207+
)
208+
209+
@property
210+
def dpctl_float32(self):
211+
cdef str p_name = "dpctl_float32"
212+
return kernel_arg_type_attribute(
213+
self._name,
214+
p_name,
215+
_arg_data_type._FLOAT
216+
)
217+
218+
@property
219+
def dpctl_float64(self):
220+
cdef str p_name = "dpctl_float64"
221+
return kernel_arg_type_attribute(
222+
self._name,
223+
p_name,
224+
_arg_data_type._DOUBLE
225+
)
226+
227+
@property
228+
def dpctl_void_ptr(self):
229+
cdef str p_name = "dpctl_void_ptr"
230+
return kernel_arg_type_attribute(
231+
self._name,
232+
p_name,
233+
_arg_data_type._VOID_PTR
234+
)
235+
236+
237+
kernel_arg_type = _kernel_arg_type()
238+
239+
91240
cdef class SyclKernelSubmitError(Exception):
92241
"""
93242
A SyclKernelSubmitError exception is raised when the provided

dpctl/enum_types.py

-19
Original file line numberDiff line numberDiff line change
@@ -113,22 +113,3 @@ class global_mem_cache_type(Enum):
113113
none = auto()
114114
read_only = auto()
115115
read_write = auto()
116-
117-
118-
class kernel_arg_type(Enum):
119-
"""
120-
An enumeration of supported kernel argument types in
121-
:func:`dpctl.SyclQueue.submit`
122-
"""
123-
124-
dpctl_int8 = 0
125-
dpctl_uint8 = auto()
126-
dpctl_int16 = auto()
127-
dpctl_uint16 = auto()
128-
dpctl_int32 = auto()
129-
dpctl_uint32 = auto()
130-
dpctl_int64 = auto()
131-
dpctl_uint64 = auto()
132-
dpctl_float32 = auto()
133-
dpctl_float64 = auto()
134-
dpctl_void_ptr = auto()

dpctl/tests/test_sycl_kernel_submit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import dpctl.memory as dpctl_mem
2727
import dpctl.program as dpctl_prog
2828
import dpctl.tensor as dpt
29-
from dpctl.enum_types import kernel_arg_type
29+
from dpctl._sycl_queue import kernel_arg_type
3030

3131

3232
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)