Skip to content

Commit 39f7e41

Browse files
committed
Minor build fixes and test skeleton
1 parent 9cf650f commit 39f7e41

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
3+
import torch
4+
import torch_xla
5+
from absl.testing import absltest, parameterized
6+
import torch_xla.core.xla_env_vars as xenv
7+
8+
class TestExperimentalPjrtMultiCpu(parameterized.TestCase):
9+
10+
def setUp(self):
11+
pjrt.set_device_type('CPU')
12+
13+
os.environ.pop(xenv.CPU_ASYNC_CLIENT, None)
14+
os.environ.pop(xenv.CPU_NUM_DEVICES, None)
15+
16+
def test_default_cpu_device(self):
17+
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
18+
print(devices_per_process)
19+
20+
def test_multi_cpu_devices(self):
21+
os.environ.update({
22+
xenv.CPU_ASYNC_CLIENT: True,
23+
xenv.CPU_NUM_DEVICES: 4,
24+
})
25+
devices_per_process = pjrt.run_multiprocess(xm.xla_device)
26+
print(devices_per_process)
27+
28+
29+
30+
if __name__ == '__main__':
31+
absltest.main()

third_party/xla_client/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ cc_library(
150150
"//tensorflow/compiler/xla/pjrt:cpu_device",
151151
"//tensorflow/compiler/xla/pjrt:tpu_client",
152152
"//tensorflow/compiler/xla/pjrt:pjrt_client",
153+
"//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
153154
"//tensorflow/compiler/xla/rpc:grpc_stub",
154155
"//tensorflow/compiler/xla/service:cpu_plugin",
155156
"//tensorflow/compiler/xla/service:platform_util",

third_party/xla_client/pjrt_computation_client.cc

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
1111
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
1212
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
13+
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
1314
#include "tensorflow/compiler/xla/shape.h"
1415
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
1516
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"

0 commit comments

Comments
 (0)