Skip to content

Commit b7e3aff

Browse files
ZhaoyueCheng authored and
tensorflower-gardener committed
Add sparse_core_embedding_config to TPUEmbeddingLayer and add device assignment to VF in TPU distribute strategy.
PiperOrigin-RevId: 646175965
1 parent 952b8f0 commit b7e3aff

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

official/common/distribute_utils.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,32 @@ def get_distribution_strategy(distribution_strategy="mirrored",
148148

149149
if distribution_strategy == "tpu":
150150
# When tpu_address is an empty string, we communicate with local TPUs.
151-
cluster_resolver = tpu_initialize(tpu_address)
152-
return tf.distribute.TPUStrategy(cluster_resolver)
151+
# Workaround for a bug: on v5p the device assignment must be specified
152+
# explicitly when using the TPU strategy, so we add a device assignment to
153+
# the strategy.
154+
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
155+
tpu=tpu_address
156+
)
157+
if tpu_address not in ("", "local"):
158+
tf.config.experimental_connect_to_cluster(cluster_resolver)
159+
topology = tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
160+
161+
device_assignment = None
162+
if hasattr(tf.tpu.experimental, "HardWareFeature"):
163+
hardware_feature = tf.tpu.experimental.HardWareFeature(
164+
cluster_resolver.tpu_hardware_feature
165+
)
166+
if (
167+
hardware_feature.embedding_feature
168+
== tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
169+
):
170+
tpu_metadata = cluster_resolver.get_tpu_system_metadata()
171+
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
172+
topology, num_replicas=tpu_metadata.num_cores
173+
)
174+
175+
return tf.distribute.TPUStrategy(
176+
cluster_resolver, experimental_device_assignment=device_assignment)
153177

154178
if distribution_strategy == "multi_worker_mirrored":
155179
return tf.distribute.experimental.MultiWorkerMirroredStrategy(

0 commit comments

Comments
 (0)