@@ -15,7 +15,7 @@
 """Task for the Ranking model."""
 
 import math
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Tuple
 
 import tensorflow as tf, tf_keras
 import tensorflow_recommenders as tfrs
@@ -35,8 +35,14 @@ def _get_tpu_embedding_feature_config(
     vocab_sizes: List[int],
     embedding_dim: Union[int, List[int]],
     table_name_prefix: str = 'embedding_table',
-    batch_size: Optional[int] = None
-) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
+    batch_size: Optional[int] = None,
+    max_ids_per_chip_per_sample: Optional[int] = None,
+    max_ids_per_table: Optional[Union[int, List[int]]] = None,
+    max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
+) -> Tuple[
+    Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
+    Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
+]:
   """Returns TPU embedding feature config.
 
   i'th table config will have vocab size of vocab_sizes[i] and embedding
@@ -47,37 +53,97 @@ def _get_tpu_embedding_feature_config(
     embedding_dim: An integer or a list of embedding table dimensions.
     table_name_prefix: a prefix for embedding tables.
     batch_size: Per-replica batch size.
+    max_ids_per_chip_per_sample: Maximum number of embedding ids per chip per
+      sample.
+    max_ids_per_table: Maximum number of embedding ids per table.
+    max_unique_ids_per_table: Maximum number of unique embedding ids per table.
+
   Returns:
-    A dictionary of feature_name, FeatureConfig pairs.
+    A tuple of (feature_config dict, optional SparseCoreEmbeddingConfig).
   """
   if isinstance(embedding_dim, List):
     if len(vocab_sizes) != len(embedding_dim):
       raise ValueError(
           f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
-          f'length of embedding_dim: {len(embedding_dim)}')
+          f'length of embedding_dim: {len(embedding_dim)}'
+      )
   elif isinstance(embedding_dim, int):
     embedding_dim = [embedding_dim] * len(vocab_sizes)
   else:
-    raise ValueError('embedding_dim is not either a list or an int, got '
-                     f'{type(embedding_dim)}')
+    raise ValueError(
+        'embedding_dim is not either a list or an int, got '
+        f'{type(embedding_dim)}'
+    )
+
+  if isinstance(max_ids_per_table, List):
+    if len(vocab_sizes) != len(max_ids_per_table):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          f'length of max_ids_per_table: {len(max_ids_per_table)}'
+      )
+  elif isinstance(max_ids_per_table, int):
+    max_ids_per_table = [max_ids_per_table] * len(vocab_sizes)
+  elif max_ids_per_table is not None:
+    raise ValueError(
+        'max_ids_per_table is not either a list or an int or None, got '
+        f'{type(max_ids_per_table)}'
+    )
+
+  if isinstance(max_unique_ids_per_table, List):
+    if len(vocab_sizes) != len(max_unique_ids_per_table):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          'length of max_unique_ids_per_table: '
+          f'{len(max_unique_ids_per_table)}'
+      )
+  elif isinstance(max_unique_ids_per_table, int):
+    max_unique_ids_per_table = [max_unique_ids_per_table] * len(vocab_sizes)
+  elif max_unique_ids_per_table is not None:
+    raise ValueError(
+        'max_unique_ids_per_table is not either a list or an int or None, '
+        f'got {type(max_unique_ids_per_table)}'
+    )
 
   feature_config = {}
+  sparsecore_config = None
+  max_ids_per_table_dict = {}
+  max_unique_ids_per_table_dict = {}
 
   for i, vocab_size in enumerate(vocab_sizes):
     table_config = tf.tpu.experimental.embedding.TableConfig(
         vocabulary_size=vocab_size,
         dim=embedding_dim[i],
         combiner='mean',
         initializer=tf.initializers.TruncatedNormal(
-            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
-        name=table_name_prefix + '_%02d' % i)
+            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])
+        ),
+        name=table_name_prefix + '_%02d' % i,
+    )
     feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
         name=str(i),
         table=table_config,
         output_shape=[batch_size] if batch_size else None,
     )
+    if max_ids_per_table:
+      max_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
+          max_ids_per_table[i]
+      )
+    if max_unique_ids_per_table:
+      max_unique_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
+          max_unique_ids_per_table[i]
+      )
 
-  return feature_config
+  if all((max_ids_per_chip_per_sample, max_ids_per_table,
+          max_unique_ids_per_table)):
+    sparsecore_config = tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
+        disable_table_stacking=False,
+        max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
+        max_ids_per_table=max_ids_per_table_dict,
+        max_unique_ids_per_table=max_unique_ids_per_table_dict,
+        allow_id_dropping=False,
+    )
+
+  return feature_config, sparsecore_config
 
 
 class RankingTask(base_task.Task):
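As a quick sanity check for reviewers, here is a minimal usage sketch of the reworked helper, assuming the Model Garden module path `official.recommendation.ranking.task` and a TF build that ships `SparseCoreEmbeddingConfig`; the vocab sizes and limits below are illustrative:

```python
# Hypothetical example, not part of this change: exercises the new two-value
# return and the int-to-per-table broadcasting added above.
from official.recommendation.ranking import task

feature_config, sparsecore_config = task._get_tpu_embedding_feature_config(
    vocab_sizes=[1000, 500],
    embedding_dim=16,                # int broadcasts to [16, 16]
    batch_size=128,                  # per-replica batch size
    max_ids_per_chip_per_sample=64,
    max_ids_per_table=[2048, 1024],  # one limit per table
    max_unique_ids_per_table=512,    # int broadcasts to [512, 512]
)

# All three limits were supplied, so the SparseCore config is materialized,
# with per-table dicts keyed by table name ('embedding_table_00', ...).
assert sparsecore_config is not None
assert feature_config['0'].table.vocabulary_size == 1000

# Leaving any of the three limits as None yields (feature_config, None), and
# the embedding layer then falls back to library defaults.
```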
@@ -173,25 +239,33 @@ def build_model(self) -> tf_keras.Model:
           decay_start_steps=dense_lr_config.decay_start_steps)
       dense_optimizer.learning_rate = dense_lr_callable
 
-    feature_config = _get_tpu_embedding_feature_config(
-        embedding_dim=self.task_config.model.embedding_dim,
-        vocab_sizes=self.task_config.model.vocab_sizes,
-        batch_size=self.task_config.train_data.global_batch_size
-        // tf.distribute.get_strategy().num_replicas_in_sync,
+    feature_config, sparse_core_embedding_config = (
+        _get_tpu_embedding_feature_config(
+            embedding_dim=self.task_config.model.embedding_dim,
+            vocab_sizes=self.task_config.model.vocab_sizes,
+            batch_size=self.task_config.train_data.global_batch_size
+            // tf.distribute.get_strategy().num_replicas_in_sync,
+            max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
+            max_ids_per_table=self.task_config.model.max_ids_per_table,
+            max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
+        )
     )
 
-    if self.task_config.model.use_multi_hot:
-      embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
+    # To work around a PartialTPUEmbedding issue on v5p and to enable
+    # multi-hot features.
+    if self.task_config.model.use_partial_tpu_embedding:
+      embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
          feature_config=feature_config,
          optimizer=embedding_optimizer,
          pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
+          size_threshold=self.task_config.model.size_threshold,
       )
     else:
-      embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
+      embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
          feature_config=feature_config,
          optimizer=embedding_optimizer,
          pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
-          size_threshold=self.task_config.model.size_threshold,
+          sparse_core_embedding_config=sparse_core_embedding_config,
       )
 
     if self.task_config.model.interaction == 'dot':
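The flag flip above (`use_multi_hot` becomes `use_partial_tpu_embedding`) swaps which branch builds which layer. A hedged sketch of the resulting selection logic follows; the `build_embedding_layer` wrapper and its parameter names are illustrative, not Model Garden API, and the `sparse_core_embedding_config` kwarg requires a recent tensorflow-recommenders:

```python
import tensorflow_recommenders as tfrs

def build_embedding_layer(feature_config, optimizer, use_partial_tpu_embedding,
                          pipeline_with_tensor_core, size_threshold=None,
                          sparse_core_embedding_config=None):
  """Illustrative mirror of the embedding-layer selection in build_model."""
  if use_partial_tpu_embedding:
    # PartialTPUEmbedding keeps tables below size_threshold on TensorCore as
    # dense lookups; larger tables go through the TPU embedding engine.
    return tfrs.experimental.layers.embedding.PartialTPUEmbedding(
        feature_config=feature_config,
        optimizer=optimizer,
        pipeline_execution_with_tensor_core=pipeline_with_tensor_core,
        size_threshold=size_threshold,
    )
  # Plain TPUEmbedding path, now threading through the SparseCore limits so
  # the v5p-class embedding engine can plan table stacking and id capacity
  # up front (None simply falls back to library defaults).
  return tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
      feature_config=feature_config,
      optimizer=optimizer,
      pipeline_execution_with_tensor_core=pipeline_with_tensor_core,
      sparse_core_embedding_config=sparse_core_embedding_config,
  )
```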