Skip to content

Commit d6e6483

Browse files
added memory.py to expose dpclt.memory module
modularized test, + changes per black
1 parent 9271ff7 commit d6e6483

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

dpctl/memory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._memory import MemoryUSMShared, MemoryUSMDevice, MemoryUSMHost
2+
3+
__all__ = ["MemoryUSMShared", "MemoryUSMDevice", "MemoryUSMHost"]

dpctl/tests/test_sycl_usm.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,20 @@ def test_memory_create(self):
3636
queue = dpctl.get_current_queue()
3737
mobj = MemoryUSMShared(nbytes, queue)
3838
self.assertEqual(mobj.nbytes, nbytes)
39-
self.assertTrue(hasattr(mobj, '__sycl_usm_array_interface__'))
39+
self.assertTrue(hasattr(mobj, "__sycl_usm_array_interface__"))
4040

4141
def _create_memory(self):
4242
nbytes = 1024
4343
queue = dpctl.get_current_queue()
4444
mobj = MemoryUSMShared(nbytes, queue)
4545
return mobj
4646

47+
def _create_host_buf(self, nbytes):
48+
ba = bytearray(nbytes)
49+
for i in range(nbytes):
50+
ba[i] = (i % 32) + ord("a")
51+
return ba
52+
4753
@unittest.skipUnless(
4854
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
4955
)
@@ -97,9 +103,7 @@ def test_buffer_protocol(self):
97103
)
98104
def test_copy_host_roundtrip(self):
99105
mobj = self._create_memory()
100-
host_src_obj = bytearray(mobj.nbytes)
101-
for i in range(mobj.nbytes):
102-
host_src_obj[i] = (i % 32) + ord('a')
106+
host_src_obj = self._create_host_buf(mobj.nbytes)
103107
mobj.copy_from_host(host_src_obj)
104108
host_dest_obj = mobj.copy_to_host()
105109
del mobj
@@ -113,22 +117,24 @@ def test_zero_copy(self):
113117
mobj2 = type(mobj)(mobj)
114118

115119
self.assertTrue(mobj2.reference_obj is mobj)
116-
self.assertTrue(mobj2.__sycl_usm_array_interface__['data'] == mobj.__sycl_usm_array_interface__['data'])
120+
mobj_data = mobj.__sycl_usm_array_interface__["data"]
121+
mobj2_data = mobj2.__sycl_usm_array_interface__["data"]
122+
self.assertEqual(mobj_data, mobj2_data)
117123

118124
@unittest.skipUnless(
119125
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
120126
)
121127
def test_pickling(self):
122128
import pickle
129+
123130
mobj = self._create_memory()
124-
host_src_obj = bytearray(mobj.nbytes)
125-
for i in range(mobj.nbytes):
126-
host_src_obj[i] = (i % 32) + ord('a')
131+
host_src_obj = self._create_host_buf(mobj.nbytes)
127132
mobj.copy_from_host(host_src_obj)
128133

129-
mobj2 = pickle.loads(pickle.dumps(mobj))
130-
self.assertEqual(mobj.tobytes(), mobj2.tobytes())
131-
self.assertNotEqual(mobj._pointer, mobj2._pointer)
134+
mobj_reconstructed = pickle.loads(pickle.dumps(mobj))
135+
self.assertEqual(mobj.tobytes(), mobj_reconstructed.tobytes())
136+
self.assertNotEqual(mobj._pointer, mobj_reconstructed._pointer)
137+
132138

133139
class TestMemoryUSMBase:
134140
""" Base tests for MemoryUSM* """
@@ -175,7 +181,5 @@ class TestMemoryUSMDevice(TestMemoryUSMBase, unittest.TestCase):
175181
usm_type = "device"
176182

177183

178-
179-
180184
if __name__ == "__main__":
181185
unittest.main()

0 commit comments

Comments
 (0)