File tree 7 files changed +2
-53
lines changed
7 files changed +2
-53
lines changed Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -77,8 +77,7 @@ function run_xla_backend_mp {
77
77
78
78
function run_pjrt {
79
79
echo " Running in PjRt runtime: $@ "
80
- # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
81
- PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test " $@ "
80
+ PJRT_DEVICE=CPU run_test " $@ "
82
81
}
83
82
84
83
function run_async_scalar {
Original file line number Diff line number Diff line change @@ -150,7 +150,6 @@ cc_library(
150
150
"//tensorflow/compiler/xla/pjrt:cpu_device" ,
151
151
"//tensorflow/compiler/xla/pjrt:tpu_client" ,
152
152
"//tensorflow/compiler/xla/pjrt:pjrt_client" ,
153
- "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client" ,
154
153
"//tensorflow/compiler/xla/rpc:grpc_stub" ,
155
154
"//tensorflow/compiler/xla/service:cpu_plugin" ,
156
155
"//tensorflow/compiler/xla/service:platform_util" ,
Original file line number Diff line number Diff line change @@ -20,7 +20,6 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
20
20
const char * const kEnvPjRtDevice = " PJRT_DEVICE" ;
21
21
const char * const kEnvPjRtTpuMaxInflightComputations =
22
22
" PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS" ;
23
- const char * const kEnvPjrtAsyncCpuClient = " PJRT_CPU_ASYNC_CLIENT" ;
24
23
25
24
} // namespace env
26
25
} // namespace xla
Original file line number Diff line number Diff line change @@ -20,7 +20,6 @@ extern const char* const kEnvStartService;
20
20
extern const char * const kEnvTpuvmMode ;
21
21
extern const char * const kEnvPjRtDevice ;
22
22
extern const char * const kEnvPjRtTpuMaxInflightComputations ;
23
- extern const char * const kEnvPjrtAsyncCpuClient ;
24
23
25
24
} // namespace env
26
25
} // namespace xla
Original file line number Diff line number Diff line change 9
9
#include " tensorflow/compiler/xla/literal.h"
10
10
#include " tensorflow/compiler/xla/pjrt/cpu_device.h"
11
11
#include " tensorflow/compiler/xla/pjrt/pjrt_client.h"
12
- #include " tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
13
12
#include " tensorflow/compiler/xla/pjrt/tpu_client.h"
14
13
#include " tensorflow/compiler/xla/shape.h"
15
14
#include " tensorflow/compiler/xla/xla_client/computation_client.h"
@@ -47,10 +46,7 @@ PjRtComputationClient::PjRtComputationClient() {
47
46
std::string device_type = sys_util::GetEnvString (env::kEnvPjRtDevice , " " );
48
47
if (device_type == " CPU" ) {
49
48
TF_VLOG (1 ) << " Initializing PjRt CPU client..." ;
50
- bool async = sys_util::GetEnvBool (env::kEnvPjrtAsyncCpuClient , true );
51
- int cpu_device_count = sys_util::GetEnvInt (env::kEnvNumCpu , 1 );
52
- client_ =
53
- std::move (xla::GetTfrtCpuClient (async, cpu_device_count).ValueOrDie ());
49
+ client_ = std::move (xla::GetCpuClient (/* asynchronous=*/ false ).ValueOrDie ());
54
50
} else if (device_type == " TPU" ) {
55
51
TF_VLOG (1 ) << " Initializing PjRt TPU client..." ;
56
52
int64_t max_inflight_computations = sys_util::GetEnvInt (
Original file line number Diff line number Diff line change 26
26
TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
27
27
TPU_VISIBLE_CHIPS = 'TPU_VISIBLE_CHIPS'
28
28
TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
29
- PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'
You can’t perform that action at this time.
0 commit comments