@@ -36,14 +36,20 @@ def test_memory_create(self):
36
36
queue = dpctl .get_current_queue ()
37
37
mobj = MemoryUSMShared (nbytes , queue )
38
38
self .assertEqual (mobj .nbytes , nbytes )
39
- self .assertTrue (hasattr (mobj , ' __sycl_usm_array_interface__' ))
39
+ self .assertTrue (hasattr (mobj , " __sycl_usm_array_interface__" ))
40
40
41
41
def _create_memory (self ):
42
42
nbytes = 1024
43
43
queue = dpctl .get_current_queue ()
44
44
mobj = MemoryUSMShared (nbytes , queue )
45
45
return mobj
46
46
47
+ def _create_host_buf (self , nbytes ):
48
+ ba = bytearray (nbytes )
49
+ for i in range (nbytes ):
50
+ ba [i ] = (i % 32 ) + ord ("a" )
51
+ return ba
52
+
47
53
@unittest .skipUnless (
48
54
dpctl .has_sycl_platforms (), "No SYCL devices except the default host device."
49
55
)
@@ -97,9 +103,7 @@ def test_buffer_protocol(self):
97
103
)
98
104
def test_copy_host_roundtrip (self ):
99
105
mobj = self ._create_memory ()
100
- host_src_obj = bytearray (mobj .nbytes )
101
- for i in range (mobj .nbytes ):
102
- host_src_obj [i ] = (i % 32 ) + ord ('a' )
106
+ host_src_obj = self ._create_host_buf (mobj .nbytes )
103
107
mobj .copy_from_host (host_src_obj )
104
108
host_dest_obj = mobj .copy_to_host ()
105
109
del mobj
@@ -113,22 +117,24 @@ def test_zero_copy(self):
113
117
mobj2 = type (mobj )(mobj )
114
118
115
119
self .assertTrue (mobj2 .reference_obj is mobj )
116
- self .assertTrue (mobj2 .__sycl_usm_array_interface__ ['data' ] == mobj .__sycl_usm_array_interface__ ['data' ])
120
+ mobj_data = mobj .__sycl_usm_array_interface__ ["data" ]
121
+ mobj2_data = mobj2 .__sycl_usm_array_interface__ ["data" ]
122
+ self .assertEqual (mobj_data , mobj2_data )
117
123
118
124
@unittest .skipUnless (
119
125
dpctl .has_sycl_platforms (), "No SYCL devices except the default host device."
120
126
)
121
127
def test_pickling (self ):
122
128
import pickle
129
+
123
130
mobj = self ._create_memory ()
124
- host_src_obj = bytearray (mobj .nbytes )
125
- for i in range (mobj .nbytes ):
126
- host_src_obj [i ] = (i % 32 ) + ord ('a' )
131
+ host_src_obj = self ._create_host_buf (mobj .nbytes )
127
132
mobj .copy_from_host (host_src_obj )
128
133
129
- mobj2 = pickle .loads (pickle .dumps (mobj ))
130
- self .assertEqual (mobj .tobytes (), mobj2 .tobytes ())
131
- self .assertNotEqual (mobj ._pointer , mobj2 ._pointer )
134
+ mobj_reconstructed = pickle .loads (pickle .dumps (mobj ))
135
+ self .assertEqual (mobj .tobytes (), mobj_reconstructed .tobytes ())
136
+ self .assertNotEqual (mobj ._pointer , mobj_reconstructed ._pointer )
137
+
132
138
133
139
class TestMemoryUSMBase :
134
140
""" Base tests for MemoryUSM* """
@@ -175,7 +181,5 @@ class TestMemoryUSMDevice(TestMemoryUSMBase, unittest.TestCase):
175
181
usm_type = "device"
176
182
177
183
178
-
179
-
180
184
if __name__ == "__main__" :
181
185
unittest .main ()
0 commit comments