File tree 3 files changed +33
-0
lines changed
3 files changed +33
-0
lines changed Original file line number Diff line number Diff line change
1
+ import os
2
+
3
+ import torch
4
+ import torch_xla
5
+ from absl .testing import absltest , parameterized
6
+ import torch_xla .core .xla_env_vars as xenv
7
+
8
+ class TestExperimentalPjrtMultiCpu (parameterized .TestCase ):
9
+
10
+ def setUp (self ):
11
+ pjrt .set_device_type ('CPU' )
12
+
13
+ os .environ .pop (xenv .CPU_ASYNC_CLIENT , None )
14
+ os .environ .pop (xenv .CPU_NUM_DEVICES , None )
15
+
16
+ def test_default_cpu_device (self ):
17
+ devices_per_process = pjrt .run_multiprocess (xm .xla_device )
18
+ print (devices_per_process )
19
+
20
+ def test_multi_cpu_devices (self ):
21
+ os .environ .update ({
22
+ xenv .CPU_ASYNC_CLIENT : True ,
23
+ xenv .CPU_NUM_DEVICES : 4 ,
24
+ })
25
+ devices_per_process = pjrt .run_multiprocess (xm .xla_device )
26
+ print (devices_per_process )
27
+
28
+
29
+
30
+ if __name__ == '__main__' :
31
+ absltest .main ()
Original file line number Diff line number Diff line change @@ -150,6 +150,7 @@ cc_library(
150
150
"//tensorflow/compiler/xla/pjrt:cpu_device" ,
151
151
"//tensorflow/compiler/xla/pjrt:tpu_client" ,
152
152
"//tensorflow/compiler/xla/pjrt:pjrt_client" ,
153
+ "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client" ,
153
154
"//tensorflow/compiler/xla/rpc:grpc_stub" ,
154
155
"//tensorflow/compiler/xla/service:cpu_plugin" ,
155
156
"//tensorflow/compiler/xla/service:platform_util" ,
Original file line number Diff line number Diff line change 10
10
#include " tensorflow/compiler/xla/pjrt/cpu_device.h"
11
11
#include " tensorflow/compiler/xla/pjrt/pjrt_client.h"
12
12
#include " tensorflow/compiler/xla/pjrt/tpu_client.h"
13
+ #include " tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
13
14
#include " tensorflow/compiler/xla/shape.h"
14
15
#include " tensorflow/compiler/xla/xla_client/computation_client.h"
15
16
#include " tensorflow/compiler/xla/xla_client/debug_macros.h"
You can’t perform that action at this time.
0 commit comments