
Commit 7b9996f

Internal change
PiperOrigin-RevId: 507827478
1 parent: c0e9168

1 file changed, 11 insertions(+), 11 deletions(-)

official/nlp/docs/tfhub.md

@@ -72,10 +72,10 @@ encoder_inputs = dict(
 )
 encoder_outputs = encoder(encoder_inputs)
 assert encoder_outputs.keys() == {
-    "pooled_output",    # Shape [batch_size, width], dtype=float32
-    "default",          # Alias for "pooled_output" (aligns with other models)
-    "sequence_output",  # Shape [batch_size, seq_length, width], dtype=float32
-    "encoder_outputs",  # List of Tensors with outputs of all transformer layers
+    "pooled_output",    # Shape [batch_size, width], dtype=float32
+    "default",          # Alias for "pooled_output" (aligns with other models)
+    "sequence_output",  # Shape [batch_size, seq_length, width], dtype=float32
+    "encoder_outputs",  # List of Tensors with outputs of all transformer layers
 }
 ```
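
The snippet edited above documents the encoder's output signature. For reference, here is a minimal sketch of exercising that signature; the Hub handle, batch size, and sequence length are assumptions for illustration and are not part of the doc, and any BERT-style encoder exported with this signature behaves the same way:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Assumed example handle; substitute any compatible encoder export.
encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

batch_size, seq_length = 2, 128
encoder_inputs = dict(
    input_word_ids=tf.zeros([batch_size, seq_length], dtype=tf.int32),
    input_mask=tf.ones([batch_size, seq_length], dtype=tf.int32),
    input_type_ids=tf.zeros([batch_size, seq_length], dtype=tf.int32),
)
encoder_outputs = encoder(encoder_inputs)

print(sorted(encoder_outputs.keys()))            # the keys asserted in the doc
print(encoder_outputs["pooled_output"].shape)    # [batch_size, width]
print(encoder_outputs["sequence_output"].shape)  # [batch_size, seq_length, width]
```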

@@ -170,10 +170,10 @@ mlm_inputs = dict(
 )
 mlm_outputs = encoder.mlm(mlm_inputs)
 assert mlm_outputs.keys() == {
-    "pooled_output",    # Shape [batch, width], dtype=float32
-    "sequence_output",  # Shape [batch, seq_length, width], dtype=float32
-    "encoder_outputs",  # List of Tensors with outputs of all transformer layers
-    "mlm_logits"        # Shape [batch, num_predictions, vocab_size], dtype=float32
+    "pooled_output",    # Shape [batch, width], dtype=float32
+    "sequence_output",  # Shape [batch, seq_length, width], dtype=float32
+    "encoder_outputs",  # List of Tensors with outputs of all transformer layers
+    "mlm_logits"        # Shape [batch, num_predictions, vocab_size], dtype=float32
 }
 ```
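
This hunk touches the doc's description of the optional MLM interface, `encoder.mlm(...)`. A rough sketch of driving it, assuming the encoder was loaded with `hub.load()` from an export that includes the MLM head, and assuming the MLM signature takes the three standard encoder inputs plus `masked_lm_positions` (the input keys are not shown in this hunk):

```python
import tensorflow as tf
import tensorflow_hub as hub

# Placeholder handle (an assumption): replace with an encoder export that
# carries the extra MLM head described in the surrounding documentation.
encoder = hub.load("<handle of an encoder exported with the MLM head>")

batch, seq_length, num_predictions = 2, 128, 20
mlm_inputs = dict(  # assumed input keys; not shown in this hunk
    input_word_ids=tf.zeros([batch, seq_length], dtype=tf.int32),
    input_mask=tf.ones([batch, seq_length], dtype=tf.int32),
    input_type_ids=tf.zeros([batch, seq_length], dtype=tf.int32),
    masked_lm_positions=tf.zeros([batch, num_predictions], dtype=tf.int32),
)
mlm_outputs = encoder.mlm(mlm_inputs)
print(mlm_outputs["mlm_logits"].shape)  # [batch, num_predictions, vocab_size]
```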

@@ -246,9 +246,9 @@ preprocessor = hub.load(...)
 text_input = ...  # Shape [batch_size], dtype=tf.string
 encoder_inputs = preprocessor(text_input, seq_length=seq_length)
 assert encoder_inputs.keys() == {
-    "input_word_ids",  # Shape [batch_size, seq_length], dtype=int32
-    "input_mask",      # Shape [batch_size, seq_length], dtype=int32
-    "input_type_ids"   # Shape [batch_size, seq_length], dtype=int32
+    "input_word_ids",  # Shape [batch_size, seq_length], dtype=int32
+    "input_mask",      # Shape [batch_size, seq_length], dtype=int32
+    "input_type_ids"   # Shape [batch_size, seq_length], dtype=int32
 }
 ```
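
The last hunk documents the preprocessor's output contract, which is exactly the input dict the encoder expects. A short sketch wiring a preprocessor and encoder together in a Keras model; the two handles are assumptions chosen for illustration, and it uses the preprocessor's default calling convention (a fixed sequence length) rather than the `seq_length=` keyword shown in the hunk, which may not be available on every published preprocessor:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Assumed example handles; any matching preprocessor/encoder pair works.
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
encoder_inputs = preprocessor(text_input)  # dict: input_word_ids/mask/type_ids
encoder_outputs = encoder(encoder_inputs)  # dict: pooled_output, sequence_output, ...
embedding_model = tf.keras.Model(text_input, encoder_outputs["pooled_output"])

print(embedding_model(tf.constant(["hello world"])).shape)  # [1, width]
```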
