Skip to content

Commit f23740f

Browse files
committed
Revert "Use tfrt cpu client (#3898)"
This reverts commit a6ea0f3.
1 parent 1f79e32 commit f23740f

File tree

7 files changed

+2
-53
lines changed

7 files changed

+2
-53
lines changed

test/pjrt/test_experimental_pjrt_multi_cpu.py

-42
This file was deleted.

test/run_tests.sh

+1-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ function run_xla_backend_mp {
7777

7878
function run_pjrt {
7979
echo "Running in PjRt runtime: $@"
80-
# TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
81-
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
80+
PJRT_DEVICE=CPU run_test "$@"
8281
}
8382

8483
function run_async_scalar {

third_party/xla_client/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ cc_library(
150150
"//tensorflow/compiler/xla/pjrt:cpu_device",
151151
"//tensorflow/compiler/xla/pjrt:tpu_client",
152152
"//tensorflow/compiler/xla/pjrt:pjrt_client",
153-
"//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
154153
"//tensorflow/compiler/xla/rpc:grpc_stub",
155154
"//tensorflow/compiler/xla/service:cpu_plugin",
156155
"//tensorflow/compiler/xla/service:platform_util",

third_party/xla_client/env_vars.cc

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
2020
const char* const kEnvPjRtDevice = "PJRT_DEVICE";
2121
const char* const kEnvPjRtTpuMaxInflightComputations =
2222
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
23-
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
2423

2524
} // namespace env
2625
} // namespace xla

third_party/xla_client/env_vars.h

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ extern const char* const kEnvStartService;
2020
extern const char* const kEnvTpuvmMode;
2121
extern const char* const kEnvPjRtDevice;
2222
extern const char* const kEnvPjRtTpuMaxInflightComputations;
23-
extern const char* const kEnvPjrtAsyncCpuClient;
2423

2524
} // namespace env
2625
} // namespace xla

third_party/xla_client/pjrt_computation_client.cc

+1-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "tensorflow/compiler/xla/literal.h"
1010
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
1111
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
12-
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
1312
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
1413
#include "tensorflow/compiler/xla/shape.h"
1514
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
@@ -47,10 +46,7 @@ PjRtComputationClient::PjRtComputationClient() {
4746
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
4847
if (device_type == "CPU") {
4948
TF_VLOG(1) << "Initializing PjRt CPU client...";
50-
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
51-
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
52-
client_ =
53-
std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
49+
client_ = std::move(xla::GetCpuClient(/*asynchronous=*/false).ValueOrDie());
5450
} else if (device_type == "TPU") {
5551
TF_VLOG(1) << "Initializing PjRt TPU client...";
5652
int64_t max_inflight_computations = sys_util::GetEnvInt(

torch_xla/core/xla_env_vars.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,3 @@
2626
TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
2727
TPU_VISIBLE_CHIPS = 'TPU_VISIBLE_CHIPS'
2828
TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
29-
PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'

0 commit comments

Comments
 (0)