@@ -72,10 +72,10 @@ encoder_inputs = dict(
)
encoder_outputs = encoder(encoder_inputs)
assert encoder_outputs.keys() == {
-    "pooled_output",     # Shape [batch_size, width], dtype=float32
-    "default",           # Alias for "pooled_output" (aligns with other models)
-    "sequence_output",   # Shape [batch_size, seq_length, width], dtype=float32
-    "encoder_outputs",   # List of Tensors with outputs of all transformer layers
+    "pooled_output",     # Shape [batch_size, width], dtype=float32
+    "default",           # Alias for "pooled_output" (aligns with other models)
+    "sequence_output",   # Shape [batch_size, seq_length, width], dtype=float32
+    "encoder_outputs",   # List of Tensors with outputs of all transformer layers
}
```
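For context, a minimal sketch of calling an encoder that follows this API; the model handle, batch size, and seq_length below are illustrative assumptions, not part of this change:

```
import tensorflow as tf
import tensorflow_hub as hub

# Illustrative handle; any encoder implementing this common API should work.
encoder = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

batch_size, seq_length = 2, 128  # Illustrative sizes.
encoder_inputs = dict(
    input_word_ids=tf.zeros([batch_size, seq_length], tf.int32),  # Token ids.
    input_mask=tf.ones([batch_size, seq_length], tf.int32),       # 1 marks real tokens.
    input_type_ids=tf.zeros([batch_size, seq_length], tf.int32),  # Segment ids.
)
encoder_outputs = encoder(encoder_inputs)
pooled = encoder_outputs["pooled_output"]      # [batch_size, width]
sequence = encoder_outputs["sequence_output"]  # [batch_size, seq_length, width]
```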
@@ -170,10 +170,10 @@ mlm_inputs = dict(
)
mlm_outputs = encoder.mlm(mlm_inputs)
assert mlm_outputs.keys() == {
-    "pooled_output",     # Shape [batch, width], dtype=float32
-    "sequence_output",   # Shape [batch, seq_length, width], dtype=float32
-    "encoder_outputs",   # List of Tensors with outputs of all transformer layers
-    "mlm_logits"         # Shape [batch, num_predictions, vocab_size], dtype=float32
+    "pooled_output",     # Shape [batch, width], dtype=float32
+    "sequence_output",   # Shape [batch, seq_length, width], dtype=float32
+    "encoder_outputs",   # List of Tensors with outputs of all transformer layers
+    "mlm_logits"         # Shape [batch, num_predictions, vocab_size], dtype=float32
}
```
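A sketch of driving the MLM sub-API shown above. The diff only shows the outputs, so two things here are assumptions: that the loaded encoder actually exports the `.mlm` sub-object, and that `mlm_inputs` carries the three standard encoder inputs plus a `masked_lm_positions` key, as in the TF Hub common API for text:

```
import tensorflow as tf
import tensorflow_hub as hub

# Illustrative handle; assumes this encoder exports the .mlm sub-object.
encoder = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

batch, seq_length, num_predictions = 2, 128, 20  # Illustrative sizes.
mlm_inputs = dict(
    input_word_ids=tf.zeros([batch, seq_length], tf.int32),
    input_mask=tf.ones([batch, seq_length], tf.int32),
    input_type_ids=tf.zeros([batch, seq_length], tf.int32),
    # Assumed key: positions of the tokens to predict, per the common text API.
    masked_lm_positions=tf.zeros([batch, num_predictions], tf.int32),
)
mlm_outputs = encoder.mlm(mlm_inputs)
# Most likely vocabulary id at each masked position.
predicted_ids = tf.argmax(mlm_outputs["mlm_logits"], axis=-1)
```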
@@ -246,9 +246,9 @@ preprocessor = hub.load(...)
text_input = ...  # Shape [batch_size], dtype=tf.string
encoder_inputs = preprocessor(text_input, seq_length=seq_length)
assert encoder_inputs.keys() == {
-    "input_word_ids",   # Shape [batch_size, seq_length], dtype=int32
-    "input_mask",       # Shape [batch_size, seq_length], dtype=int32
-    "input_type_ids"    # Shape [batch_size, seq_length], dtype=int32
+    "input_word_ids",   # Shape [batch_size, seq_length], dtype=int32
+    "input_mask",       # Shape [batch_size, seq_length], dtype=int32
+    "input_type_ids"    # Shape [batch_size, seq_length], dtype=int32
}
```
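Putting the pieces together, a sketch that chains a matching preprocessor and encoder. The handles are illustrative; note this uses the preprocessor's plain default call (with that model's fixed seq_length), whereas the `seq_length=` argument documented above may not be accepted by every released preprocessor:

```
import tensorflow as tf
import tensorflow_hub as hub

# Illustrative, matching preprocessor/encoder pair.
preprocessor = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
encoder = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

text_input = tf.constant(["hello world", "tensorflow hub"])
encoder_inputs = preprocessor(text_input)  # Uses this model's default seq_length.
encoder_outputs = encoder(encoder_inputs)
embeddings = encoder_outputs["pooled_output"]  # One vector per input string.
```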