Make feature converter cls configurable #942

Open · wants to merge 1 commit into base: main
t5/data/preprocessors.py (121 additions, 0 deletions)
@@ -3072,3 +3072,124 @@ def noise_token_to_random_token_or_sentinel(
tokens, noise_mask, vocabulary, seeds=seeds[1:]),
noise_token_to_sentinel(
tokens, noise_mask, vocabulary, seeds=()))


# =============== EXPERIMENTAL preprocessors (not used for the T5 paper) =======


def trim_and_pad_dataset(dataset, sequence_length):
"""A wrapper to use `seqio.utils.trim_and_pad_dataset` as a preprocessor."""
return seqio.utils.trim_and_pad_dataset(
dataset, feature_lengths=sequence_length)


def targets_for_prefix_lm_objective(dataset, sequence_length, output_features):
"""Prepares targets to be used for prefix LM objective."""
dataset = select_random_chunk(
dataset, output_features, max_length=65536, feature_key='targets')
dataset = seqio.preprocessors.append_eos(dataset, output_features)
dataset = reduce_concat_tokens(dataset, batch_size=128)
dataset = trim_and_pad_dataset(dataset, sequence_length)
return dataset

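# Illustrative note (derived from the calls above): with, for example,
# sequence_length = {'inputs': 8, 'targets': 8}, select_random_chunk picks a
# random span of up to 65536 tokens from 'targets', append_eos appends the EOS
# id, reduce_concat_tokens concatenates 128 examples into a single long
# 'targets' sequence, and trim_and_pad_dataset cuts/pads it back to length 8,
# so every element carries a fixed-length 'targets' vector ready for the
# packing preprocessors below.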

def pack_prefix_lm_encoder_decoder(ds, sequence_length, pad_id=0):
"""Pack two examples into one with the prefix LM objective."""
packed_length = next(iter(sequence_length.values()))
assert packed_length % 2 == 0
assert all(l == packed_length for l in sequence_length.values())

@seqio.utils.map_over_dataset(num_seeds=1)
def pack_examples(example_pair, seed):
split_point = tf.random.stateless_uniform((),
minval=1,
maxval=packed_length,
seed=seed,
dtype=tf.int32)
inputs = tf.concat([
example_pair['targets'][0][:split_point],
example_pair['targets'][1][:packed_length - split_point]
],
axis=0)
inputs = tf.reshape(inputs, (packed_length,))
targets = tf.concat([
example_pair['targets'][0][split_point:],
example_pair['targets'][1][packed_length - split_point:]
],
axis=0)
targets = tf.reshape(targets, (packed_length,))

encoder_segment_ids = tf.cast(
tf.range(packed_length) >= split_point, tf.int32) + 1
decoder_segment_ids = tf.cast(
tf.range(packed_length) >= (packed_length - split_point), tf.int32) + 1

decoder_input_tokens = seqio.utils.make_autoregressive_inputs(
targets, sequence_id=decoder_segment_ids)

encoder_positions = tf.concat(
[tf.range(split_point),
tf.range(packed_length - split_point)], axis=0)
encoder_positions = tf.reshape(encoder_positions, (packed_length,))
decoder_positions = tf.concat(
[tf.range(packed_length - split_point),
tf.range(split_point)], axis=0)
decoder_positions = tf.reshape(decoder_positions, (packed_length,))
decoder_loss_weights = tf.cast(
tf.not_equal(targets, pad_id), dtype=tf.int32)
return {
'encoder_input_tokens': inputs,
'decoder_target_tokens': targets,
'decoder_input_tokens': decoder_input_tokens,
'encoder_segment_ids': encoder_segment_ids,
'encoder_positions': encoder_positions,
'decoder_segment_ids': decoder_segment_ids,
'decoder_positions': decoder_positions,
'decoder_loss_weights': decoder_loss_weights,
}

# Note that batching two examples requires their lengths to be the same.
return pack_examples(ds.batch(2))

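# Worked example (values mirror the unit test in preprocessors_test.py): with
# packed_length = 8 and split_point = 3, packing
#   {'targets': [0, 1, 2, 3, 4, 5, 6, 7]} and
#   {'targets': [8, 9, 10, 11, 12, 13, 14, 15]}
# produces
#   encoder_input_tokens  = [0, 1, 2, 8, 9, 10, 11, 12]
#   decoder_target_tokens = [3, 4, 5, 6, 7, 13, 14, 15]
#   encoder_segment_ids   = [1, 1, 1, 2, 2, 2, 2, 2]
#   decoder_segment_ids   = [1, 1, 1, 1, 1, 2, 2, 2]
# i.e. the first split_point tokens of the first example plus the first
# (packed_length - split_point) tokens of the second become the encoder input,
# and the remaining tokens of each example become the decoder targets.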

def pack_prefix_lm_decoder_only(ds,
sequence_length,
loss_on_targets_only=True,
pad_id=0):
"""Randomly split the tokens for the prefix LM objective."""
packed_length = next(iter(sequence_length.values()))
assert packed_length % 2 == 0
assert all(l == packed_length for l in sequence_length.values())

@seqio.utils.map_over_dataset(num_seeds=1)
def pack_examples(example, seed):
split_point = tf.random.stateless_uniform((),
minval=1,
maxval=packed_length,
seed=seed,
dtype=tf.int32)
decoder_target_tokens = example['targets']
decoder_input_tokens = seqio.utils.make_autoregressive_inputs(
decoder_target_tokens)

if loss_on_targets_only:
decoder_loss_weights = tf.cast(
tf.range(packed_length) >= split_point, tf.int32)
else:
decoder_loss_weights = tf.ones((packed_length,), dtype=tf.int32)

padding_mask = tf.cast(
tf.not_equal(decoder_target_tokens, pad_id), dtype=tf.int32)
decoder_loss_weights *= padding_mask

decoder_causal_attention = tf.cast(
tf.range(packed_length) <= split_point, tf.int32)

return {
'decoder_target_tokens': decoder_target_tokens,
'decoder_input_tokens': decoder_input_tokens,
'decoder_loss_weights': decoder_loss_weights,
'decoder_causal_attention': decoder_causal_attention,
}

return pack_examples(ds)
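
# Worked example (values mirror the unit test in preprocessors_test.py): with
# packed_length = 8, split_point = 3 and loss_on_targets_only=True, the example
#   {'targets': [9, 1, 2, 3, 4, 5, 6, 7]}
# produces
#   decoder_input_tokens     = [0, 9, 1, 2, 3, 4, 5, 6]
#   decoder_loss_weights     = [0, 0, 0, 1, 1, 1, 1, 1]
#   decoder_causal_attention = [1, 1, 1, 1, 0, 0, 0, 0]
# i.e. positions up to and including split_point are marked as the fully
# visible prefix, and the loss is computed only from position split_point
# onward.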
t5/data/preprocessors_test.py (120 additions, 0 deletions)
@@ -1979,6 +1979,126 @@ def test_select_random_chunk_different_sizes(self):
additional_feature_keys=['inputs'], max_length=4)
_ = list(dataset.as_numpy_iterator())

def test_pack_prefix_lm_encoder_decoder(self):

x = [{'targets': [0, 1, 2, 3, 4, 5, 6, 7]},
{'targets': [8, 9, 10, 11, 12, 13, 14, 15]},
{'targets': [16, 17, 18, 19, 20, 21, 22, 23]},
{'targets': [24, 25, 26, 27, 28, 29, 30, 31]}]
ds = test_utils.create_default_dataset(
x, feature_names=['targets'], output_shapes={'targets': [8]})

# With this seed, split points are 3 and 5.
with seqio.utils.map_seed_manager(2):
packed_ds = prep.pack_prefix_lm_encoder_decoder(
ds, {'inputs': 8, 'targets': 8})

expected = [
{
'encoder_input_tokens': [0, 1, 2, 8, 9, 10, 11, 12],
'decoder_target_tokens': [3, 4, 5, 6, 7, 13, 14, 15],
# The first token of the second sequence (in this case index 5)
# should be 0 instead of the last token of the first sequence.
'decoder_input_tokens': [0, 3, 4, 5, 6, 0, 13, 14],
'encoder_segment_ids': [1, 1, 1, 2, 2, 2, 2, 2],
'encoder_positions': [0, 1, 2, 0, 1, 2, 3, 4],
'decoder_loss_weights': [1, 1, 1, 1, 1, 1, 1, 1],
'decoder_segment_ids': [1, 1, 1, 1, 1, 2, 2, 2],
'decoder_positions': [0, 1, 2, 3, 4, 0, 1, 2],
},
{
'encoder_input_tokens': [16, 17, 18, 19, 20, 24, 25, 26],
'decoder_target_tokens': [21, 22, 23, 27, 28, 29, 30, 31],
'decoder_input_tokens': [0, 21, 22, 0, 27, 28, 29, 30],
'encoder_segment_ids': [1, 1, 1, 1, 1, 2, 2, 2],
'encoder_positions': [0, 1, 2, 3, 4, 0, 1, 2],
'decoder_loss_weights': [1, 1, 1, 1, 1, 1, 1, 1],
'decoder_segment_ids': [1, 1, 1, 2, 2, 2, 2, 2],
'decoder_positions': [0, 1, 2, 0, 1, 2, 3, 4],
}
]
assert_dataset(packed_ds, expected)

def test_pack_prefix_lm_encoder_decoder_with_padding(self):
x = [{'targets': [9, 1, 2, 3, 4, 5, 6, 0]},
{'targets': [8, 9, 10, 11, 12, 13, 0, 0]}]
ds = test_utils.create_default_dataset(
x, feature_names=['targets'], output_shapes={'targets': [8]})

# With this seed, split point is 3.
with seqio.utils.map_seed_manager(2):
packed_ds = prep.pack_prefix_lm_encoder_decoder(
ds, {'inputs': 8, 'targets': 8})

expected = [
{
'encoder_input_tokens': [9, 1, 2, 8, 9, 10, 11, 12],
'decoder_target_tokens': [3, 4, 5, 6, 0, 13, 0, 0],
'decoder_input_tokens': [0, 3, 4, 5, 6, 0, 13, 0],
'encoder_segment_ids': [1, 1, 1, 2, 2, 2, 2, 2],
'encoder_positions': [0, 1, 2, 0, 1, 2, 3, 4],
'decoder_loss_weights': [1, 1, 1, 1, 0, 1, 0, 0],
'decoder_segment_ids': [1, 1, 1, 1, 1, 2, 2, 2],
'decoder_positions': [0, 1, 2, 3, 4, 0, 1, 2],
},
]
assert_dataset(packed_ds, expected)

def test_pack_prefix_lm_decoder_only(self):
x = [{'targets': [9, 1, 2, 3, 4, 5, 6, 7]},
{'targets': [8, 9, 10, 11, 12, 13, 14, 15]}]
ds = test_utils.create_default_dataset(x, feature_names=['targets'])

# With this seed, split points are 3 and 5.
with seqio.utils.map_seed_manager(2):
packed_ds = prep.pack_prefix_lm_decoder_only(ds, {'length': 8})

expected = [{
'decoder_target_tokens': [9, 1, 2, 3, 4, 5, 6, 7],
'decoder_input_tokens': [0, 9, 1, 2, 3, 4, 5, 6],
'decoder_loss_weights': [0, 0, 0, 1, 1, 1, 1, 1],
'decoder_causal_attention': [1, 1, 1, 1, 0, 0, 0, 0],
}, {
'decoder_target_tokens': [8, 9, 10, 11, 12, 13, 14, 15],
'decoder_input_tokens': [0, 8, 9, 10, 11, 12, 13, 14],
'decoder_loss_weights': [0, 0, 0, 0, 0, 1, 1, 1],
'decoder_causal_attention': [1, 1, 1, 1, 1, 1, 0, 0],
}]
assert_dataset(packed_ds, expected)

def test_pack_prefix_lm_decoder_only_with_padding(self):
x = [{'targets': [8, 9, 10, 11, 12, 13, 0, 0]}]
ds = test_utils.create_default_dataset(x, feature_names=['targets'])

# With this seed, split point is 3.
with seqio.utils.map_seed_manager(2):
packed_ds = prep.pack_prefix_lm_decoder_only(ds, {'length': 8})

expected = [{
'decoder_target_tokens': [8, 9, 10, 11, 12, 13, 0, 0],
'decoder_input_tokens': [0, 8, 9, 10, 11, 12, 13, 0],
'decoder_loss_weights': [0, 0, 0, 1, 1, 1, 0, 0],
'decoder_causal_attention': [1, 1, 1, 1, 0, 0, 0, 0],
}]
assert_dataset(packed_ds, expected)

def test_pack_prefix_lm_decoder_only_with_padding_loss_on_targets_false(self):
x = [{'targets': [8, 9, 10, 11, 12, 13, 0, 0]}]
ds = test_utils.create_default_dataset(x, feature_names=['targets'])

# With this seed, split point is 3.
with seqio.utils.map_seed_manager(2):
packed_ds = prep.pack_prefix_lm_decoder_only(
ds, {'length': 8}, loss_on_targets_only=False)

expected = [{
'decoder_target_tokens': [8, 9, 10, 11, 12, 13, 0, 0],
'decoder_input_tokens': [0, 8, 9, 10, 11, 12, 13, 0],
'decoder_loss_weights': [1, 1, 1, 1, 1, 1, 0, 0],
'decoder_causal_attention': [1, 1, 1, 1, 0, 0, 0, 0],
}]
assert_dataset(packed_ds, expected)


if __name__ == '__main__':
absltest.main()
t5/data/tasks.py (59 additions, 0 deletions)
@@ -401,3 +401,62 @@
],
metric_fns=[],
output_features=DEFAULT_OUTPUT_FEATURES)


# =============== PrefixLM objectives (not used in the T5 paper) ===============


# Vocabulary (shared by encoder and decoder)
sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"

vocab = seqio.SentencePieceVocabulary(sentencepiece_model_file)
prefix_lm_obj_output_features = {
"encoder_input_tokens": seqio.Feature(vocabulary=vocab),
"decoder_target_tokens": seqio.Feature(vocabulary=vocab),
"decoder_input_tokens": seqio.Feature(vocabulary=vocab),
"encoder_segment_ids": seqio.Feature(vocabulary=vocab),
"encoder_positions": seqio.Feature(vocabulary=vocab),
"decoder_segment_ids": seqio.Feature(vocabulary=vocab),
"decoder_positions": seqio.Feature(vocabulary=vocab),
"decoder_loss_weights": seqio.Feature(vocabulary=vocab),
# All but the last stage of the preprocessing uses "targets" as the key, so
# this output feature is necessary. It is not marked required because the
# final preprocessor drops it.
"targets": seqio.Feature(vocabulary=vocab, required=False),
}


seqio.TaskRegistry.add(
"c4_prefix_lm_objective_encoder_decoder_architecture",
source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
preprocessors=[
functools.partial(
preprocessors.rekey, key_map={
"inputs": None,
"targets": "text"
}),
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
preprocessors.targets_for_prefix_lm_objective,
preprocessors.pack_prefix_lm_encoder_decoder,
],
output_features=prefix_lm_obj_output_features,
metric_fns=[])


seqio.TaskRegistry.add(
"c4_prefix_lm_objective_decoder_architecture",
source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
preprocessors=[
functools.partial(
preprocessors.rekey, key_map={
"inputs": None,
"targets": "text"
}),
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
preprocessors.targets_for_prefix_lm_objective,
preprocessors.pack_prefix_lm_decoder_only,
],
output_features=prefix_lm_obj_output_features,
metric_fns=[])
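
For reference, a minimal usage sketch of the encoder-decoder task registered above, loaded with seqio. The sequence_length values, split, and shuffle settings are assumptions chosen for illustration; the packing preprocessor only requires that all lengths are equal and even.

import seqio

# Build the preprocessed, packed C4 prefix-LM dataset registered above.
task = seqio.get_mixture_or_task(
    "c4_prefix_lm_objective_encoder_decoder_architecture")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},  # must all be equal and even
    split="train",
    shuffle=False)
for ex in ds.take(1):
    print(ex["encoder_input_tokens"].shape, ex["decoder_target_tokens"].shape)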