Skip to content

Commit 9f09c7e

Browse files
committed
Implement PJRT multi CPU tests
1 parent 39f7e41 commit 9f09c7e

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed
+31-20
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,42 @@
11
import os
2-
2+
import collections
33
import torch
44
import torch_xla
55
from absl.testing import absltest, parameterized
6+
import torch_xla.core.xla_model as xm
67
import torch_xla.core.xla_env_vars as xenv
8+
from torch_xla.experimental import pjrt
79

8-
class TestExperimentalPjrtMultiCpu(parameterized.TestCase):
9-
10-
def setUp(self):
11-
pjrt.set_device_type('CPU')
12-
13-
os.environ.pop(xenv.CPU_ASYNC_CLIENT, None)
14-
os.environ.pop(xenv.CPU_NUM_DEVICES, None)
1510

16-
def test_default_cpu_device(self):
17-
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
18-
print(devices_per_process)
19-
20-
def test_multi_cpu_devices(self):
21-
os.environ.update({
22-
xenv.CPU_ASYNC_CLIENT: True,
23-
xenv.CPU_NUM_DEVICES: 4,
24-
})
25-
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
26-
print(devices_per_process)
11+
class TestExperimentalPjrtMultiCpu(parameterized.TestCase):
2712

13+
def setUp(self):
14+
pjrt.set_device_type('CPU')
15+
16+
os.environ.pop(xenv.CPU_NUM_DEVICES, None)
17+
os.environ.pop(xenv.CPU_ASYNC_CLIENT, None)
18+
19+
def test_default_cpu_device(self):
20+
expected = {0: {0: torch.device('xla:0'),}}
21+
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
22+
self.assertDictEqual(devices_per_process, expected)
23+
24+
def test_multi_cpu_devices(self):
25+
expected = {
26+
0: {
27+
0: torch.device('xla:0'),
28+
1: torch.device('xla:1'),
29+
2: torch.device('xla:2'),
30+
3: torch.device('xla:3')
31+
}
32+
}
33+
os.environ.update({
34+
xenv.CPU_ASYNC_CLIENT: 'true',
35+
xenv.CPU_NUM_DEVICES: '4',
36+
})
37+
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
38+
self.assertDictEqual(devices_per_process, expected)
2839

2940

3041
if __name__ == '__main__':
31-
absltest.main()
42+
absltest.main()

torch_xla/core/xla_env_vars.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
2727
TPU_VISIBLE_DEVICES = 'TPU_VISIBLE_DEVICES'
2828
TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
29+
CPU_ASYNC_CLIENT = 'CPU_ASYNC_CLIENT'

0 commit comments

Comments
 (0)