diff --git a/dpctl/_sycl_queue.pyx b/dpctl/_sycl_queue.pyx index 542b7b5a47..0dec0990c3 100644 --- a/dpctl/_sycl_queue.pyx +++ b/dpctl/_sycl_queue.pyx @@ -88,6 +88,155 @@ __all__ = [ _logger = logging.getLogger(__name__) +cdef class kernel_arg_type_attribute: + cdef str parent_name + cdef str attr_name + cdef int attr_value + + def __cinit__(self, str parent, str name, int value): + self.parent_name = parent + self.attr_name = name + self.attr_value = value + + def __repr__(self): + return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>" + + def __str__(self): + return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>" + + @property + def name(self): + return self.attr_name + + @property + def value(self): + return self.attr_value + + +cdef class _kernel_arg_type: + """ + An enumeration of supported kernel argument types in + :func:`dpctl.SyclQueue.submit` + """ + cdef str _name + + def __cinit__(self): + self._name = "kernel_arg_type" + + + @property + def __name__(self): + return self._name + + def __repr__(self): + return "" + + def __str__(self): + return "" + + @property + def dpctl_int8(self): + cdef str p_name = "dpctl_int8" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._INT8_T + ) + + @property + def dpctl_uint8(self): + cdef str p_name = "dpctl_uint8" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._UINT8_T + ) + + @property + def dpctl_int16(self): + cdef str p_name = "dpctl_int16" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._INT16_T + ) + + @property + def dpctl_uint16(self): + cdef str p_name = "dpctl_uint16" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._UINT16_T + ) + + @property + def dpctl_int32(self): + cdef str p_name = "dpctl_int32" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._INT32_T + ) + + @property + def dpctl_uint32(self): + cdef str p_name = "dpctl_uint32" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._UINT32_T + ) + + @property + def dpctl_int64(self): + cdef str p_name = "dpctl_int64" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._INT64_T + ) + + @property + def dpctl_uint64(self): + cdef str p_name = "dpctl_uint64" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._UINT64_T + ) + + @property + def dpctl_float32(self): + cdef str p_name = "dpctl_float32" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._FLOAT + ) + + @property + def dpctl_float64(self): + cdef str p_name = "dpctl_float64" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._DOUBLE + ) + + @property + def dpctl_void_ptr(self): + cdef str p_name = "dpctl_void_ptr" + return kernel_arg_type_attribute( + self._name, + p_name, + _arg_data_type._VOID_PTR + ) + + +kernel_arg_type = _kernel_arg_type() + + cdef class SyclKernelSubmitError(Exception): """ A SyclKernelSubmitError exception is raised when the provided diff --git a/dpctl/enum_types.py b/dpctl/enum_types.py index bb8c54b7be..102ae09015 100644 --- a/dpctl/enum_types.py +++ b/dpctl/enum_types.py @@ -113,22 +113,3 @@ class global_mem_cache_type(Enum): none = auto() read_only = auto() read_write = auto() - - -class kernel_arg_type(Enum): - """ - An enumeration of supported kernel argument types in - :func:`dpctl.SyclQueue.submit` - """ - - dpctl_int8 = auto() - dpctl_uint8 = auto() - dpctl_int16 = auto() - dpctl_uint16 = auto() - dpctl_int32 = auto() - dpctl_uint32 = auto() - dpctl_int64 = auto() - dpctl_uint64 = auto() - dpctl_float32 = auto() - dpctl_float64 = auto() - dpctl_void_ptr = auto() diff --git a/dpctl/tests/test_sycl_kernel_submit.py b/dpctl/tests/test_sycl_kernel_submit.py index aa590388e9..01558dd4df 100644 --- a/dpctl/tests/test_sycl_kernel_submit.py +++ b/dpctl/tests/test_sycl_kernel_submit.py @@ -26,6 +26,7 @@ import dpctl.memory as dpctl_mem import dpctl.program as dpctl_prog import dpctl.tensor as dpt +from dpctl._sycl_queue import kernel_arg_type @pytest.mark.parametrize( @@ -244,3 +245,32 @@ def test_submit_async(): Xref[2, i] = min(Xref[0, i], Xref[1, i]) assert np.array_equal(Xnp[:, :n], Xref[:, :n]) + + +def _check_kernel_arg_type_instance(kati): + assert isinstance(kati.name, str) + assert isinstance(kati.value, int) + assert isinstance(repr(kati), str) + assert isinstance(str(kati), str) + + +def test_kernel_arg_type(): + """ + Check that enum values for kernel_arg_type start at 0, + as numba_dpex expects. The next enumerated type must + have next value. + """ + assert isinstance(kernel_arg_type.__name__, str) + assert isinstance(repr(kernel_arg_type), str) + assert isinstance(str(kernel_arg_type), str) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_int8) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint8) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_int16) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint16) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_int32) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint32) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_int64) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint64) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_float32) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64) + _check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)