@@ -148,8 +148,32 @@ def get_distribution_strategy(distribution_strategy="mirrored",
 
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
-    cluster_resolver = tpu_initialize(tpu_address)
-    return tf.distribute.TPUStrategy(cluster_resolver)
+    # Bug workaround: on v5p we need to explicitly specify the device
+    # assignment when using TPUStrategy, so build one and attach it to the
+    # strategy.
+    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+        tpu=tpu_address
+    )
+    if tpu_address not in ("", "local"):
+      tf.config.experimental_connect_to_cluster(cluster_resolver)
+    topology = tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+
+    device_assignment = None
+    if hasattr(tf.tpu.experimental, "HardwareFeature"):
+      hardware_feature = tf.tpu.experimental.HardwareFeature(
+          cluster_resolver.tpu_hardware_feature
+      )
+      if (
+          hardware_feature.embedding_feature
+          == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
+      ):
+        tpu_metadata = cluster_resolver.get_tpu_system_metadata()
+        device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+            topology, num_replicas=tpu_metadata.num_cores
+        )
+
+    return tf.distribute.TPUStrategy(
+        cluster_resolver, experimental_device_assignment=device_assignment)
 
   if distribution_strategy == "multi_worker_mirrored":
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
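
For reference, a minimal usage sketch of the updated TPU branch (not part of the change itself): it assumes get_distribution_strategy accepts a tpu_address argument, as suggested by the body above, and that a TPU runtime is reachable (for example on a TPU VM, where "local" targets the attached chips).

    import tensorflow as tf

    # Hypothetical call; both argument names are assumptions from the diff above.
    strategy = get_distribution_strategy(
        distribution_strategy="tpu", tpu_address="local")

    with strategy.scope():
      # Variables created in this scope are replicated across the TPU cores,
      # using the explicit device assignment built for v5p when applicable.
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      model.compile(optimizer="sgd", loss="mse")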