
Use tfrt cpu client #3898

Merged · 4 commits · Aug 18, 2022
42 changes: 42 additions & 0 deletions test/pjrt/test_experimental_pjrt_multi_cpu.py
@@ -0,0 +1,42 @@
import os
import collections
import torch
import torch_xla
from absl.testing import absltest, parameterized
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_env_vars as xenv
from torch_xla.experimental import pjrt


class TestExperimentalPjrtMultiCpu(parameterized.TestCase):

  def setUp(self):
    pjrt.set_device_type('CPU')

    os.environ.pop(xenv.CPU_NUM_DEVICES, None)
    os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None)

  def test_default_cpu_device(self):
    expected = {0: {0: torch.device('xla:0'),}}
    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
    self.assertDictEqual(devices_per_process, expected)

  def test_multi_cpu_devices(self):
    expected = {
        0: {
            0: torch.device('xla:0'),
            1: torch.device('xla:1'),
            2: torch.device('xla:2'),
            3: torch.device('xla:3')
        }
    }
    os.environ.update({
        xenv.PJRT_CPU_ASYNC_CLIENT: 'true',
        xenv.CPU_NUM_DEVICES: '4',
    })
    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
    self.assertDictEqual(devices_per_process, expected)


if __name__ == '__main__':
  absltest.main()
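
For context, here is a rough sketch of the pattern the new test exercises, outside the test harness: run a small computation on every CPU device and collect the results per process. The add_ones helper is hypothetical; pjrt.run_multiprocess, pjrt.set_device_type, and the xenv names are taken from the test above.

import os

import torch
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def add_ones():
  # Hypothetical helper: runs once per replica and returns a plain Python
  # scalar so results can be pickled back to the parent process.
  device = xm.xla_device()
  t = torch.ones(2, 2, device=device) + torch.ones(2, 2, device=device)
  return t.sum().item()


if __name__ == '__main__':
  os.environ[xenv.CPU_NUM_DEVICES] = '4'  # assumed: four virtual CPU devices
  pjrt.set_device_type('CPU')
  results = pjrt.run_multiprocess(add_ones)
  print(results)  # expected shape: {process: {thread: 8.0, ...}, ...}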
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -77,7 +77,8 @@ function run_xla_backend_mp {
 
 function run_pjrt {
   echo "Running in PjRt runtime: $@"
-  PJRT_DEVICE=CPU run_test "$@"
+  # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
+  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
 }
 
 function run_async_scalar {
1 change: 1 addition & 0 deletions third_party/xla_client/BUILD
@@ -150,6 +150,7 @@ cc_library(
         "//tensorflow/compiler/xla/pjrt:cpu_device",
         "//tensorflow/compiler/xla/pjrt:tpu_client",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
         "//tensorflow/compiler/xla/rpc:grpc_stub",
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/compiler/xla/service:platform_util",
1 change: 1 addition & 0 deletions third_party/xla_client/env_vars.cc
@@ -20,6 +20,7 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
 const char* const kEnvPjRtDevice = "PJRT_DEVICE";
 const char* const kEnvPjRtTpuMaxInflightComputations =
     "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
+const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
 
 }  // namespace env
 }  // namespace xla
1 change: 1 addition & 0 deletions third_party/xla_client/env_vars.h
@@ -20,6 +20,7 @@ extern const char* const kEnvStartService;
 extern const char* const kEnvTpuvmMode;
 extern const char* const kEnvPjRtDevice;
 extern const char* const kEnvPjRtTpuMaxInflightComputations;
+extern const char* const kEnvPjrtAsyncCpuClient;
 
 }  // namespace env
 }  // namespace xla
5 changes: 4 additions & 1 deletion third_party/xla_client/pjrt_computation_client.cc
@@ -10,6 +10,7 @@
 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
+#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -46,7 +47,9 @@ PjRtComputationClient::PjRtComputationClient() {
   std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
   if (device_type == "CPU") {
     TF_VLOG(1) << "Initializing PjRt CPU client...";
-    client_ = std::move(xla::GetCpuClient(/*asynchronous=*/false).ValueOrDie());
+    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
+    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
+    client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
   } else if (device_type == "TPU") {
     TF_VLOG(1) << "Initializing PjRt TPU client...";
     int64_t max_inflight_computations = sys_util::GetEnvInt(
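
Taken together, the CPU client is now configured entirely through environment variables read at construction time. A minimal usage sketch, assuming the variable names from the diffs above (values are illustrative; the client is typically built lazily on first device access, so the environment must be set before that happens):

import os

# Select the PjRt CPU path with two virtual devices and the synchronous
# TFRT client; leaving PJRT_CPU_ASYNC_CLIENT unset now defaults to async.
os.environ['PJRT_DEVICE'] = 'CPU'
os.environ['CPU_NUM_DEVICES'] = '2'
os.environ['PJRT_CPU_ASYNC_CLIENT'] = 'false'

import torch_xla.core.xla_model as xm

print(xm.get_xla_supported_devices('CPU'))  # e.g. ['xla:0', 'xla:1']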
1 change: 1 addition & 0 deletions torch_xla/core/xla_env_vars.py
@@ -26,3 +26,4 @@
 TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
 TPU_VISIBLE_DEVICES = 'TPU_VISIBLE_DEVICES'
 TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
+PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'
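
With the constant exported, callers can toggle the async client symbolically rather than through a raw string, mirroring the style of the test above (a small sketch; the effect is identical to setting PJRT_CPU_ASYNC_CLIENT directly):

import os
import torch_xla.core.xla_env_vars as xenv

# Force the synchronous TFRT CPU client; 'true' (the default) enables async.
os.environ[xenv.PJRT_CPU_ASYNC_CLIENT] = 'false'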