File tree 3 files changed +5
-1
lines changed
3 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -20,6 +20,7 @@ const char* const kEnvTpuvmMode = "TPUVM_MODE";
20
20
const char * const kEnvPjRtDevice = " PJRT_DEVICE" ;
21
21
const char * const kEnvPjRtTpuMaxInflightComputations =
22
22
" PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS" ;
23
+ const char * const kEnvAsyncCpuClient = " CPU_ASYNC_CLIENT" ;
23
24
24
25
} // namespace env
25
26
} // namespace xla
Original file line number Diff line number Diff line change @@ -20,6 +20,7 @@ extern const char* const kEnvStartService;
20
20
extern const char * const kEnvTpuvmMode ;
21
21
extern const char * const kEnvPjRtDevice ;
22
22
extern const char * const kEnvPjRtTpuMaxInflightComputations ;
23
+ extern const char * const kEnvAsyncCpuClient ;
23
24
24
25
} // namespace env
25
26
} // namespace xla
Original file line number Diff line number Diff line change @@ -46,7 +46,9 @@ PjRtComputationClient::PjRtComputationClient() {
46
46
std::string device_type = sys_util::GetEnvString (env::kEnvPjRtDevice , " " );
47
47
if (device_type == " CPU" ) {
48
48
TF_VLOG (1 ) << " Initializing PjRt CPU client..." ;
49
- client_ = std::move (xla::GetCpuClient (/* asynchronous=*/ false ).ValueOrDie ());
49
+ bool async = sys_util::GetEnvBool (env::kEnvAsyncCpuClient , true );
50
+ int cpu_device_count = sys_util::GetEnvInt (env::kEnvNumCpu , 1 );
51
+ client_ = std::move (xla::GetTfrtCpuClient (async, cpu_device_count).ValueOrDie ());
50
52
} else if (device_type == " TPU" ) {
51
53
TF_VLOG (1 ) << " Initializing PjRt TPU client..." ;
52
54
int64_t max_inflight_computations = sys_util::GetEnvInt (
You can’t perform that action at this time.
0 commit comments