1
1
import os
2
-
2
+ import collections
3
3
import torch
4
4
import torch_xla
5
5
from absl .testing import absltest , parameterized
6
+ import torch_xla .core .xla_model as xm
6
7
import torch_xla .core .xla_env_vars as xenv
8
+ from torch_xla .experimental import pjrt
7
9
8
- class TestExperimentalPjrtMultiCpu (parameterized .TestCase ):
9
-
10
- def setUp (self ):
11
- pjrt .set_device_type ('CPU' )
12
-
13
- os .environ .pop (xenv .CPU_ASYNC_CLIENT , None )
14
- os .environ .pop (xenv .CPU_NUM_DEVICES , None )
15
10
16
- def test_default_cpu_device (self ):
17
- devices_per_process = pjrt .run_multiprocess (xm .xla_device )
18
- print (devices_per_process )
19
-
20
- def test_multi_cpu_devices (self ):
21
- os .environ .update ({
22
- xenv .CPU_ASYNC_CLIENT : True ,
23
- xenv .CPU_NUM_DEVICES : 4 ,
24
- })
25
- devices_per_process = pjrt .run_multiprocess (xm .xla_device )
26
- print (devices_per_process )
11
+ class TestExperimentalPjrtMultiCpu (parameterized .TestCase ):
27
12
13
+ def setUp (self ):
14
+ pjrt .set_device_type ('CPU' )
15
+
16
+ os .environ .pop (xenv .CPU_NUM_DEVICES , None )
17
+ os .environ .pop (xenv .CPU_ASYNC_CLIENT , None )
18
+
19
+ def test_default_cpu_device (self ):
20
+ expected = {0 : {0 : torch .device ('xla:0' ),}}
21
+ devices_per_process = pjrt .run_multiprocess (xm .xla_device )
22
+ self .assertDictEqual (devices_per_process , expected )
23
+
24
+ def test_multi_cpu_devices (self ):
25
+ expected = {
26
+ 0 : {
27
+ 0 : torch .device ('xla:0' ),
28
+ 1 : torch .device ('xla:1' ),
29
+ 2 : torch .device ('xla:2' ),
30
+ 3 : torch .device ('xla:3' )
31
+ }
32
+ }
33
+ os .environ .update ({
34
+ xenv .CPU_ASYNC_CLIENT : 'true' ,
35
+ xenv .CPU_NUM_DEVICES : '4' ,
36
+ })
37
+ devices_per_process = pjrt .run_multiprocess (xm .xla_device )
38
+ self .assertDictEqual (devices_per_process , expected )
28
39
29
40
30
41
if __name__ == '__main__' :
31
- absltest .main ()
42
+ absltest .main ()
0 commit comments