Skip to content

Commit a6ea0f3

Browse files
authored
Use tfrt cpu client (#3898)
* Use TFRT CPU client and make parameters modifiable through environment variables.
* Minor build fixes and test skeleton.
* Implement PJRT multi-CPU tests.
* Clarify env var naming and work around a failing test.
1 parent 35e6c49 commit a6ea0f3

File tree

7 files changed

+52
-2
lines changed

7 files changed

+52
-2
lines changed
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,42 @@
1+
import os
2+
import collections
3+
import torch
4+
import torch_xla
5+
from absl.testing import absltest, parameterized
6+
import torch_xla.core.xla_model as xm
7+
import torch_xla.core.xla_env_vars as xenv
8+
from torch_xla.experimental import pjrt
9+
10+
11+
class TestExperimentalPjrtMultiCpu(parameterized.TestCase):
  """Exercises the PJRT CPU client with the default and multi-device configs."""

  def setUp(self):
    # Pin the runtime to the PJRT CPU backend, then drop any leftover
    # device-count / async-client overrides so every test starts from the
    # backend's default configuration.
    pjrt.set_device_type('CPU')

    for env_var in (xenv.CPU_NUM_DEVICES, xenv.PJRT_CPU_ASYNC_CLIENT):
      os.environ.pop(env_var, None)

  def test_default_cpu_device(self):
    # With no overrides, process 0 should expose exactly one XLA device.
    want = {0: {0: torch.device('xla:0'),}}
    got = pjrt.run_multiprocess(xm.xla_device)
    self.assertDictEqual(got, want)

  def test_multi_cpu_devices(self):
    # Ask for an async client backed by four CPU devices; all four should be
    # visible under process 0 as xla:0 .. xla:3.
    want = {0: {ordinal: torch.device('xla:%d' % ordinal) for ordinal in range(4)}}
    os.environ[xenv.PJRT_CPU_ASYNC_CLIENT] = 'true'
    os.environ[xenv.CPU_NUM_DEVICES] = '4'
    got = pjrt.run_multiprocess(xm.xla_device)
    self.assertDictEqual(got, want)


if __name__ == '__main__':
  absltest.main()

test/run_tests.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ function run_xla_backend_mp {
7777

7878
function run_pjrt {
7979
echo "Running in PjRt runtime: $@"
80-
PJRT_DEVICE=CPU run_test "$@"
80+
# TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
81+
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
8182
}
8283

8384
function run_async_scalar {

third_party/xla_client/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ cc_library(
150150
"//tensorflow/compiler/xla/pjrt:cpu_device",
151151
"//tensorflow/compiler/xla/pjrt:tpu_client",
152152
"//tensorflow/compiler/xla/pjrt:pjrt_client",
153+
"//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
153154
"//tensorflow/compiler/xla/rpc:grpc_stub",
154155
"//tensorflow/compiler/xla/service:cpu_plugin",
155156
"//tensorflow/compiler/xla/service:platform_util",

third_party/xla_client/env_vars.cc

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
2020
const char* const kEnvPjRtDevice = "PJRT_DEVICE";
2121
const char* const kEnvPjRtTpuMaxInflightComputations =
2222
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
23+
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
2324

2425
} // namespace env
2526
} // namespace xla

third_party/xla_client/env_vars.h

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ extern const char* const kEnvStartService;
2020
extern const char* const kEnvTpuvmMode;
2121
extern const char* const kEnvPjRtDevice;
2222
extern const char* const kEnvPjRtTpuMaxInflightComputations;
23+
extern const char* const kEnvPjrtAsyncCpuClient;
2324

2425
} // namespace env
2526
} // namespace xla

third_party/xla_client/pjrt_computation_client.cc

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
1111
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
1212
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
13+
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
1314
#include "tensorflow/compiler/xla/shape.h"
1415
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
1516
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -46,7 +47,9 @@ PjRtComputationClient::PjRtComputationClient() {
4647
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
4748
if (device_type == "CPU") {
4849
TF_VLOG(1) << "Initializing PjRt CPU client...";
49-
client_ = std::move(xla::GetCpuClient(/*asynchronous=*/false).ValueOrDie());
50+
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
51+
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
52+
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
5053
} else if (device_type == "TPU") {
5154
TF_VLOG(1) << "Initializing PjRt TPU client...";
5255
int64_t max_inflight_computations = sys_util::GetEnvInt(

torch_xla/core/xla_env_vars.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
2727
TPU_VISIBLE_DEVICES = 'TPU_VISIBLE_DEVICES'
2828
TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
29+
PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'

0 commit comments

Comments
 (0)