Commit a983743

ZhaoyueCheng authored and tensorflower-gardener committed

add parameters for SparseCore Embedding Config on v5p to run DLRM models on cloud v5p for tutorial

PiperOrigin-RevId: 650336609
1 parent 485ec40 · commit a983743
3 files changed: 177 additions (+), 48 deletions (−)

official/recommendation/ranking/configs/config.py (+10)
@@ -112,6 +112,11 @@ class ModelConfig(hyperparams.Config):
       module
     dcn_use_bias: Flag to determine whether to use bias for the dcn interaction
       module
+    use_partial_tpu_embedding: Flag to determine whether to use partial tpu
+      embedding layer or not.
+    max_ids_per_chip_per_sample: Maximum number of ids per chip per sample.
+    max_ids_per_table: Maximum number of ids per table.
+    max_unique_ids_per_table: Maximum number of unique ids per table.
   """
   num_dense_features: int = 13
   vocab_sizes: List[int] = dataclasses.field(default_factory=list)
@@ -128,6 +133,10 @@ class ModelConfig(hyperparams.Config):
   dcn_kernel_initializer: str = 'truncated_normal'
   dcn_bias_initializer: str = 'zeros'
   dcn_use_bias: bool = True
+  use_partial_tpu_embedding: bool = True
+  max_ids_per_chip_per_sample: int | None = None
+  max_ids_per_table: Union[int, List[int]] | None = None
+  max_unique_ids_per_table: Union[int, List[int]] | None = None
 
 
 @dataclasses.dataclass
@@ -424,6 +433,7 @@ def dlrm_dcn_v2_criteo_tb_config() -> Config:
           dcn_use_bias=True,
           concat_dense=False,
           use_multi_hot=True,
+          use_partial_tpu_embedding=False,
           multi_hot_sizes=multi_hot_sizes,
       ),
       loss=Loss(label_smoothing=0.0),
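
For context, a minimal sketch of how the new ModelConfig fields might be set when adapting a config for v5p. The field names come from this commit; the table sizes and limit values below are illustrative assumptions, not tuned recommendations.

# Sketch only: field names are from this commit's config.py; the values are
# illustrative assumptions for a hypothetical three-table model.
from official.recommendation.ranking.configs import config

model = config.ModelConfig(
    num_dense_features=13,
    vocab_sizes=[40_000_000, 4_000_000, 1_000],  # hypothetical tables
    embedding_dim=128,
    # Take the plain TPUEmbedding path (see the task.py change below).
    use_partial_tpu_embedding=False,
    # SparseCore limits: a scalar broadcasts to every table; a list must
    # have one entry per table in vocab_sizes.
    max_ids_per_chip_per_sample=64,
    max_ids_per_table=[2048, 1024, 256],
    max_unique_ids_per_table=512,
)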

official/recommendation/ranking/task.py (+92 −18)
@@ -15,7 +15,7 @@
 """Task for the Ranking model."""
 
 import math
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Tuple
 
 import tensorflow as tf, tf_keras
 import tensorflow_recommenders as tfrs
@@ -35,8 +35,14 @@ def _get_tpu_embedding_feature_config(
     vocab_sizes: List[int],
     embedding_dim: Union[int, List[int]],
     table_name_prefix: str = 'embedding_table',
-    batch_size: Optional[int] = None
-) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
+    batch_size: Optional[int] = None,
+    max_ids_per_chip_per_sample: Optional[int] = None,
+    max_ids_per_table: Optional[Union[int, List[int]]] = None,
+    max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
+) -> Tuple[
+    Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
+    Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
+]:
   """Returns TPU embedding feature config.
 
   i'th table config will have vocab size of vocab_sizes[i] and embedding
@@ -47,37 +53,97 @@
     embedding_dim: An integer or a list of embedding table dimensions.
     table_name_prefix: a prefix for embedding tables.
     batch_size: Per-replica batch size.
+    max_ids_per_chip_per_sample: Maximum number of embedding ids per chip per
+      sample.
+    max_ids_per_table: Maximum number of embedding ids per table.
+    max_unique_ids_per_table: Maximum number of unique embedding ids per table.
+
   Returns:
     A dictionary of feature_name, FeatureConfig pairs.
   """
   if isinstance(embedding_dim, List):
     if len(vocab_sizes) != len(embedding_dim):
       raise ValueError(
           f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
-          f'length of embedding_dim: {len(embedding_dim)}')
+          f'length of embedding_dim: {len(embedding_dim)}'
+      )
   elif isinstance(embedding_dim, int):
     embedding_dim = [embedding_dim] * len(vocab_sizes)
   else:
-    raise ValueError('embedding_dim is not either a list or an int, got '
-                     f'{type(embedding_dim)}')
+    raise ValueError(
+        'embedding_dim is not either a list or an int, got '
+        f'{type(embedding_dim)}'
+    )
+
+  if isinstance(max_ids_per_table, List):
+    if len(vocab_sizes) != len(max_ids_per_table):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          f'length of max_ids_per_table: {len(max_ids_per_table)}'
+      )
+  elif isinstance(max_ids_per_table, int):
+    max_ids_per_table = [max_ids_per_table] * len(vocab_sizes)
+  elif max_ids_per_table is not None:
+    raise ValueError(
+        'max_ids_per_table is not either a list or an int or None, got '
+        f'{type(max_ids_per_table)}'
+    )
+
+  if isinstance(max_unique_ids_per_table, List):
+    if len(vocab_sizes) != len(max_unique_ids_per_table):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          'length of max_unique_ids_per_table: '
+          f'{len(max_unique_ids_per_table)}'
+      )
+  elif isinstance(max_unique_ids_per_table, int):
+    max_unique_ids_per_table = [max_unique_ids_per_table] * len(vocab_sizes)
+  elif max_unique_ids_per_table is not None:
+    raise ValueError(
+        'max_unique_ids_per_table is not either a list or an int or None, '
+        f'got {type(max_unique_ids_per_table)}'
+    )
 
   feature_config = {}
+  sparsecore_config = None
+  max_ids_per_table_dict = {}
+  max_unique_ids_per_table_dict = {}
 
   for i, vocab_size in enumerate(vocab_sizes):
     table_config = tf.tpu.experimental.embedding.TableConfig(
         vocabulary_size=vocab_size,
         dim=embedding_dim[i],
         combiner='mean',
         initializer=tf.initializers.TruncatedNormal(
-            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
-        name=table_name_prefix + '_%02d' % i)
+            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])
+        ),
+        name=table_name_prefix + '_%02d' % i,
+    )
     feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
         name=str(i),
         table=table_config,
         output_shape=[batch_size] if batch_size else None,
     )
+    if max_ids_per_table:
+      max_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
+          max_ids_per_table[i]
+      )
+    if max_unique_ids_per_table:
+      max_unique_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
+          max_unique_ids_per_table[i]
+      )
 
-  return feature_config
+  if all((max_ids_per_chip_per_sample, max_ids_per_table,
+          max_unique_ids_per_table)):
+    sparsecore_config = tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
+        disable_table_stacking=False,
+        max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
+        max_ids_per_table=max_ids_per_table_dict,
+        max_unique_ids_per_table=max_unique_ids_per_table_dict,
+        allow_id_dropping=False,
+    )
+
+  return feature_config, sparsecore_config
 
 
 class RankingTask(base_task.Task):
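
As an aside between hunks: a hedged usage sketch of the updated helper, assuming a TF build that exposes tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig. Argument values are invented for illustration; scalar limits broadcast across tables, and the SparseCoreEmbeddingConfig is only built when all three limits are supplied.

# Sketch: exercises the updated signature with made-up values.
feature_config, sparsecore_config = _get_tpu_embedding_feature_config(
    vocab_sizes=[1000, 2000],
    embedding_dim=8,
    batch_size=256,
    max_ids_per_chip_per_sample=64,
    max_ids_per_table=256,         # broadcast to both tables
    max_unique_ids_per_table=128,  # broadcast to both tables
)
# All three limits were passed, so a SparseCoreEmbeddingConfig is returned;
# omit any one of them and sparsecore_config comes back as None.
assert sparsecore_config is not None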
@@ -173,25 +239,33 @@ def build_model(self) -> tf_keras.Model:
           decay_start_steps=dense_lr_config.decay_start_steps)
       dense_optimizer.learning_rate = dense_lr_callable
 
-    feature_config = _get_tpu_embedding_feature_config(
-        embedding_dim=self.task_config.model.embedding_dim,
-        vocab_sizes=self.task_config.model.vocab_sizes,
-        batch_size=self.task_config.train_data.global_batch_size
-        // tf.distribute.get_strategy().num_replicas_in_sync,
+    feature_config, sparse_core_embedding_config = (
+        _get_tpu_embedding_feature_config(
+            embedding_dim=self.task_config.model.embedding_dim,
+            vocab_sizes=self.task_config.model.vocab_sizes,
+            batch_size=self.task_config.train_data.global_batch_size
+            // tf.distribute.get_strategy().num_replicas_in_sync,
+            max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
+            max_ids_per_table=self.task_config.model.max_ids_per_table,
+            max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
+        )
     )
 
-    if self.task_config.model.use_multi_hot:
-      embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
+    # To work around a PartialTPUEmbedding issue on v5p and to enable
+    # multi-hot features.
+    if self.task_config.model.use_partial_tpu_embedding:
+      embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
           feature_config=feature_config,
           optimizer=embedding_optimizer,
           pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
+          size_threshold=self.task_config.model.size_threshold,
       )
     else:
-      embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
+      embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
           feature_config=feature_config,
           optimizer=embedding_optimizer,
           pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
-          size_threshold=self.task_config.model.size_threshold,
+          sparse_core_embedding_config=sparse_core_embedding_config,
       )
 
     if self.task_config.model.interaction == 'dot':
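
For the cloud v5p tutorial use case named in the commit message, the branch above is typically driven from flags rather than edited in code. A hypothetical override string follows: the config field names are from this commit, but whether these exact keys are reachable through the standard Model Garden --params_override mechanism is an assumption of this sketch.

# Hypothetical --params_override value for a v5p DLRM run: disable
# PartialTPUEmbedding so the TPUEmbedding branch that consumes
# sparse_core_embedding_config is taken, and supply the three SparseCore
# limits. The numeric values are illustrative, not tuned.
params_override = (
    'task.model.use_partial_tpu_embedding=false,'
    'task.model.max_ids_per_chip_per_sample=64,'
    'task.model.max_ids_per_table=2048,'
    'task.model.max_unique_ids_per_table=512'
)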
