diff --git a/test/pjrt/test_experimental_pjrt_multi_cpu.py b/test/pjrt/test_experimental_pjrt_multi_cpu.py
new file mode 100644
index 000000000000..c4ae7eb7cb2a
--- /dev/null
+++ b/test/pjrt/test_experimental_pjrt_multi_cpu.py
@@ -0,0 +1,54 @@
+import os
+import collections
+from unittest import mock
+import torch
+import torch_xla
+from absl.testing import absltest, parameterized
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_env_vars as xenv
+from torch_xla.experimental import pjrt
+
+
+class TestExperimentalPjrtMultiCpu(parameterized.TestCase):
+  """Checks the devices returned by `pjrt.run_multiprocess` on CPU."""
+
+  def setUp(self):
+    super().setUp()
+
+    # Tests mutate process-wide environment variables; snapshot them here and
+    # restore the snapshot on cleanup so nothing leaks out of a test.
+    env_patcher = mock.patch.dict(os.environ)
+    env_patcher.start()
+    self.addCleanup(env_patcher.stop)
+
+    pjrt.set_device_type('CPU')
+
+    # Drop any values inherited from the caller so each test sets its own.
+    os.environ.pop(xenv.CPU_NUM_DEVICES, None)
+    os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None)
+
+  def test_default_cpu_device(self):
+    # With CPU_NUM_DEVICES unset, a single device is expected.
+    expected = {0: {0: torch.device('xla:0'),}}
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
+    self.assertDictEqual(devices_per_process, expected)
+
+  def test_multi_cpu_devices(self):
+    expected = {
+        0: {
+            0: torch.device('xla:0'),
+            1: torch.device('xla:1'),
+            2: torch.device('xla:2'),
+            3: torch.device('xla:3')
+        }
+    }
+    os.environ.update({
+        xenv.PJRT_CPU_ASYNC_CLIENT: 'true',
+        xenv.CPU_NUM_DEVICES: '4',
+    })
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
+    self.assertDictEqual(devices_per_process, expected)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/test/run_tests.sh b/test/run_tests.sh
index de5e75b1a693..b97b2779d3ca 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -77,7 +77,8 @@ function run_xla_backend_mp {
 
 function run_pjrt {
   echo "Running in PjRt runtime: $@"
-  PJRT_DEVICE=CPU run_test "$@"
+  # TODO(darisoy): run these tests with multiple CPU devices; this fails due to a TF issue.
+  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
 }
 
 function run_async_scalar {
diff --git a/third_party/xla_client/BUILD b/third_party/xla_client/BUILD
index 14d4c7a50523..5cb8835cd8be 100644
--- a/third_party/xla_client/BUILD
+++ b/third_party/xla_client/BUILD
@@ -150,6 +150,7 @@ cc_library(
         "//tensorflow/compiler/xla/pjrt:cpu_device",
         "//tensorflow/compiler/xla/pjrt:tpu_client",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
         "//tensorflow/compiler/xla/rpc:grpc_stub",
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/compiler/xla/service:platform_util",
diff --git a/third_party/xla_client/env_vars.cc b/third_party/xla_client/env_vars.cc
index 06d395f345f2..16d7d216bebf 100644
--- a/third_party/xla_client/env_vars.cc
+++ b/third_party/xla_client/env_vars.cc
@@ -20,6 +20,7 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
 const char* const kEnvPjRtDevice = "PJRT_DEVICE";
 const char* const kEnvPjRtTpuMaxInflightComputations =
     "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
+const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
 
 }  // namespace env
 }  // namespace xla
diff --git a/third_party/xla_client/env_vars.h b/third_party/xla_client/env_vars.h
index 1f6b279a88b0..3c9763aa1cc9 100644
--- a/third_party/xla_client/env_vars.h
+++ b/third_party/xla_client/env_vars.h
@@ -20,6 +20,7 @@ extern const char* const kEnvStartService;
 extern const char* const kEnvTpuvmMode;
 extern const char* const kEnvPjRtDevice;
 extern const char* const kEnvPjRtTpuMaxInflightComputations;
+extern const char* const kEnvPjrtAsyncCpuClient;
 
 }  // namespace env
 }  // namespace xla
diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 552bd07d39da..96295304212a 100644
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -10,6 +10,7 @@
 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
+#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -46,7 +47,14 @@ PjRtComputationClient::PjRtComputationClient() {
   std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
   if (device_type == "CPU") {
     TF_VLOG(1) << "Initializing PjRt CPU client...";
-    client_ = std::move(xla::GetCpuClient(/*asynchronous=*/false).ValueOrDie());
+    // PJRT_CPU_ASYNC_CLIENT is forwarded as the client's async flag (default on).
+    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
+    // The number of CPU devices is read from env::kEnvNumCpu (default 1).
+    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
+    XLA_CHECK_GT(cpu_device_count, 0)
+        << env::kEnvNumCpu << " must be a positive device count";
+    client_ = std::move(
+        xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
   } else if (device_type == "TPU") {
     TF_VLOG(1) << "Initializing PjRt TPU client...";
     int64_t max_inflight_computations = sys_util::GetEnvInt(
diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py
index 83d681b80afa..c12a91386d37 100644
--- a/torch_xla/core/xla_env_vars.py
+++ b/torch_xla/core/xla_env_vars.py
@@ -26,3 +26,4 @@
 TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
 TPU_VISIBLE_DEVICES = 'TPU_VISIBLE_DEVICES'
 TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
+PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'