From 462efd99f3c282aca4dc0c7feaf972003adef0e0 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Tue, 9 Aug 2022 21:12:47 +0200
Subject: [PATCH 01/12] Tmp lossseq

---
 megatron/model/gpt_model.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 6896df7f4..eaef54727 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -190,7 +190,28 @@ def CrossEntropy(output, labels):
             expected_number_of_tokens = loss_mask.sum()
 
         loss_mask = loss_mask.view(-1)
-        loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens
+
+        # Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
+        loss_mask = loss_mask.float()
+        sequence = 0
+        for idx,num in enumerate(loss_mask.tolist()):
+            if (num == 0) and (sequence == 0):
+                continue
+            elif (num == 0) and (sequence > 0):
+                # Sequence just finished
+                start_idx = idx - sequence
+                loss_mask[start_idx:idx] /= sequence
+                # Reset
+                sequence = 0
+            else:
+                sequence += 1
+        if sequence > 0:
+            start_idx = idx - sequence
+            loss_mask[start_idx:] /= sequence
+
+        expected_num_of_target_seqs = loss_mask.sum()
+
+        loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs#expected_number_of_tokens
         return loss
     return CrossEntropy
 

From 992446c856692f12288152de4a0924cdb7f1901a Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Wed, 10 Aug 2022 09:49:48 +0200
Subject: [PATCH 02/12] Efficient loss normalization

---
 megatron/model/gpt_model.py | 32 +++++++++++++-------------------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index eaef54727..6e02e5ebc 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -159,6 +159,17 @@ def load_state_dict(self, state_dict, strict=True):
 
         self.language_model.load_state_dict(state_dict, strict=strict)
 
+def fast_normalize(loss_mask: torch.Tensor):
+    """
+    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
+
+    Credits to @thomasw21 for this efficient implementation!
+ """ + loss_mask = loss_mask.float() + _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) + l = torch.gather(dim=0, index=inverse_indices, input=1./counts) + return loss_mask * l + def get_cross_entropy(is_prefix: bool): def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] @@ -191,27 +202,10 @@ def CrossEntropy(output, labels): loss_mask = loss_mask.view(-1) - # Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] - loss_mask = loss_mask.float() - sequence = 0 - for idx,num in enumerate(loss_mask.tolist()): - if (num == 0) and (sequence == 0): - continue - elif (num == 0) and (sequence > 0): - # Sequence just finished - start_idx = idx - sequence - loss_mask[start_idx:idx] /= sequence - # Reset - sequence = 0 - else: - sequence += 1 - if sequence > 0: - start_idx = idx - sequence - loss_mask[start_idx:] /= sequence - + loss_mask = fast_normalize(loss_mask) expected_num_of_target_seqs = loss_mask.sum() - loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs#expected_number_of_tokens + loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs return loss return CrossEntropy From 616cfe86b74d7e8d86b94748b659141a7f8c4bfc Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Wed, 10 Aug 2022 10:11:30 +0200 Subject: [PATCH 03/12] Reuse variable --- megatron/model/gpt_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 6e02e5ebc..b7cef6da4 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -167,8 +167,8 @@ def fast_normalize(loss_mask: torch.Tensor): """ loss_mask = loss_mask.float() _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) - l = torch.gather(dim=0, index=inverse_indices, input=1./counts) - return loss_mask * l + counts = torch.gather(dim=0, index=inverse_indices, input=1./counts) + return loss_mask * counts def get_cross_entropy(is_prefix: bool): def CrossEntropy(output, labels): From 900c83569a868e8ff9f94d44f0c4a0e53ab4a311 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Wed, 10 Aug 2022 10:23:15 +0200 Subject: [PATCH 04/12] Simplify division --- megatron/model/gpt_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index b7cef6da4..7a958a313 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -167,8 +167,8 @@ def fast_normalize(loss_mask: torch.Tensor): """ loss_mask = loss_mask.float() _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) - counts = torch.gather(dim=0, index=inverse_indices, input=1./counts) - return loss_mask * counts + counts = torch.gather(dim=0, index=inverse_indices, input=counts) + return loss_mask / counts def get_cross_entropy(is_prefix: bool): def CrossEntropy(output, labels): From 7bc1dd207c2c413789ae6720e57cb572c843bde5 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Mon, 15 Aug 2022 16:11:46 +0200 Subject: [PATCH 05/12] Add norm_target_loss arg --- megatron/arguments.py | 2 ++ megatron/model/gpt_model.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 4b63112e9..e65ac30e4 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -926,6 +926,8 @@ def __call__(self, parser, args, 
values, option_string=None): help='Mask loss for the end of document tokens.') group.add_argument('--loss-on-targets-only', action='store_true', help='Mask loss on input sequence.') + group.add_argument('--norm-target-loss', action='store_true', + help='Normalize the loss per target.') group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true", help='Some objectives require us to sample loss_mask. This might introduce bias towards ' 'specific positions. This option tries to un-bias the loss by reweighting loss on specific ' diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 7a958a313..056115372 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -15,7 +15,6 @@ """GPT-2 model.""" -from functools import partial import torch from megatron import get_args @@ -170,7 +169,7 @@ def fast_normalize(loss_mask: torch.Tensor): counts = torch.gather(dim=0, index=inverse_indices, input=counts) return loss_mask / counts -def get_cross_entropy(is_prefix: bool): +def get_cross_entropy(is_prefix: bool, norm_target_loss: bool): def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] @@ -197,15 +196,16 @@ def CrossEntropy(output, labels): else: average_tokens_per_sample = sequence_length expected_number_of_tokens = average_tokens_per_sample * micro_batch_size + elif norm_target_loss: + loss_mask = loss_mask.view(-1) + loss_mask = fast_normalize(loss_mask) + expected_num_of_target_seqs = loss_mask.sum() + loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs + return loss else: expected_number_of_tokens = loss_mask.sum() - loss_mask = loss_mask.view(-1) - - loss_mask = fast_normalize(loss_mask) - expected_num_of_target_seqs = loss_mask.sum() - - loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs + loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens return loss return CrossEntropy @@ -329,7 +329,7 @@ def _logits_helper(embedding, lm_output): partition_method = 'type:transformer' super().__init__(layers=self.specs, - loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), + loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix, norm_target_loss=args.norm_target_loss), topology=topo, activation_checkpoint_interval=interval, partition_method=partition_method) From fce1a98ede227f567c87bac3ce067d2a3f0de9c9 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 16 Aug 2022 11:23:05 +0200 Subject: [PATCH 06/12] Clarify loss on targets & remove kwarg --- megatron/arguments.py | 2 +- megatron/model/gpt_model.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index e65ac30e4..64681f1e0 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -927,7 +927,7 @@ def __call__(self, parser, args, values, option_string=None): group.add_argument('--loss-on-targets-only', action='store_true', help='Mask loss on input sequence.') group.add_argument('--norm-target-loss', action='store_true', - help='Normalize the loss per target.') + help='Normalize the loss per target. Used for multi-task finetuning with packing.') group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true", help='Some objectives require us to sample loss_mask. This might introduce bias towards ' 'specific positions. 
This option tries to un-bias the loss by reweighting loss on specific ' diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 056115372..8d2ff2bca 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -161,15 +161,13 @@ def load_state_dict(self, state_dict, strict=True): def fast_normalize(loss_mask: torch.Tensor): """ Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] - - Credits to @thomasw21 for this efficient implementation! """ loss_mask = loss_mask.float() _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) counts = torch.gather(dim=0, index=inverse_indices, input=counts) return loss_mask / counts -def get_cross_entropy(is_prefix: bool, norm_target_loss: bool): +def get_cross_entropy(is_prefix: bool): def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] @@ -196,7 +194,7 @@ def CrossEntropy(output, labels): else: average_tokens_per_sample = sequence_length expected_number_of_tokens = average_tokens_per_sample * micro_batch_size - elif norm_target_loss: + elif args.norm_target_loss: loss_mask = loss_mask.view(-1) loss_mask = fast_normalize(loss_mask) expected_num_of_target_seqs = loss_mask.sum() @@ -329,7 +327,7 @@ def _logits_helper(embedding, lm_output): partition_method = 'type:transformer' super().__init__(layers=self.specs, - loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix, norm_target_loss=args.norm_target_loss), + loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, activation_checkpoint_interval=interval, partition_method=partition_method) From 2e7554d7f0ed4d7c200b04c8c29c84d47f29c90c Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Wed, 17 Aug 2022 11:14:09 +0200 Subject: [PATCH 07/12] Loss mask is already float --- megatron/model/gpt_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 8d2ff2bca..1d23b4ec4 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -162,7 +162,6 @@ def fast_normalize(loss_mask: torch.Tensor): """ Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] """ - loss_mask = loss_mask.float() _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) counts = torch.gather(dim=0, index=inverse_indices, input=counts) return loss_mask / counts From a6b262404d583af09b17a77761f5c8d3b06d1ad5 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Wed, 17 Aug 2022 11:15:53 +0200 Subject: [PATCH 08/12] Move norm to batch pipe --- finetune_t0.py | 12 ++++++++++++ megatron/model/gpt_model.py | 10 ---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/finetune_t0.py b/finetune_t0.py index e4af9b91f..a6fbd3a78 100644 --- a/finetune_t0.py +++ b/finetune_t0.py @@ -48,6 +48,14 @@ def model_provider(pre_process=True, post_process=True): return model +def fast_normalize(loss_mask: torch.Tensor): + """ + Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] + """ + _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) + counts = torch.gather(dim=0, index=inverse_indices, input=counts) + return loss_mask / counts + def get_batch_pipe(data): """ Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion @@ -95,6 +103,10 @@ def 
get_batch_pipe(data): segment_ids=segment_ids.long(), ) + if args.norm_target_loss: + loss_mask = loss_mask.view(-1) + loss_mask = fast_normalize(loss_mask) + if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 1d23b4ec4..fee4e1d8e 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -158,14 +158,6 @@ def load_state_dict(self, state_dict, strict=True): self.language_model.load_state_dict(state_dict, strict=strict) -def fast_normalize(loss_mask: torch.Tensor): - """ - Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] - """ - _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) - counts = torch.gather(dim=0, index=inverse_indices, input=counts) - return loss_mask / counts - def get_cross_entropy(is_prefix: bool): def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] @@ -194,8 +186,6 @@ def CrossEntropy(output, labels): average_tokens_per_sample = sequence_length expected_number_of_tokens = average_tokens_per_sample * micro_batch_size elif args.norm_target_loss: - loss_mask = loss_mask.view(-1) - loss_mask = fast_normalize(loss_mask) expected_num_of_target_seqs = loss_mask.sum() loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs return loss From 549f4993ef57095ef90d324c8c92a19124fae397 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 20 Aug 2022 15:54:30 +0200 Subject: [PATCH 09/12] Reshape loss mask --- megatron/model/gpt_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index fee4e1d8e..a508ed0a5 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -166,6 +166,8 @@ def CrossEntropy(output, labels): losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) + loss_mask = loss_mask.view(-1) + if is_prefix: micro_batch_size, sequence_length = loss_mask.shape average_tokens_per_sample: torch.Tensor From d9a91febb2fbca83e28a49b4596225e47595570a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 20 Aug 2022 16:03:56 +0200 Subject: [PATCH 10/12] Move view --- megatron/model/gpt_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index a508ed0a5..3b8cbb960 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -166,8 +166,6 @@ def CrossEntropy(output, labels): losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) - loss_mask = loss_mask.view(-1) - if is_prefix: micro_batch_size, sequence_length = loss_mask.shape average_tokens_per_sample: torch.Tensor @@ -194,6 +192,7 @@ def CrossEntropy(output, labels): else: expected_number_of_tokens = loss_mask.sum() + loss_mask = loss_mask.view(-1) loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens return loss return CrossEntropy From 6c1018f67bbae58ad88a7cac52f29cca3b5b7f6a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 30 Aug 2022 19:24:18 +0200 Subject: [PATCH 11/12] Add multiple evaluation compat --- finetune_t0.py | 47 ++++++++++++++++++++++++++----------- megatron/model/gpt_model.py | 2 +- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/finetune_t0.py b/finetune_t0.py index 
a6fbd3a78..88e90a586 100644 --- a/finetune_t0.py +++ b/finetune_t0.py @@ -2,7 +2,9 @@ import torch +from pretrain_gpt import get_batch_pipe as get_batch_pipe_gpt from megatron import get_args, get_tokenizer, print_rank_0, mpu +from megatron.data.gpt_dataset import build_dataset_group as build_dataset_group_gpt from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group from megatron.enums import PositionEmbeddingType, AttnMaskType from megatron.model import GPTModelPipe @@ -65,6 +67,9 @@ def get_batch_pipe(data): decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]] """ + if 'text' in data: + return get_batch_pipe_gpt(data) + args = get_args() tokenizer = get_tokenizer() @@ -154,20 +159,34 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): eval(f"args.{s}_weighted_split_splits"), eval(f"args.{s}_weighted_split_names")) for paths, weights, splits, name in data_groups: - d = build_dataset_group( - dataset_group_name=name, - paths=paths, - weights=weights, - splits=splits, - data_impl=args.data_impl, - train_valid_test_num_samples=train_val_test_num_samples, - seq_length=args.seq_length + 1, - pad_token=tokenizer.pad, - eos_token=tokenizer.eos, - seed=args.seed, - skip_warmup=(not args.mmap_warmup), - train_valid_test=s - ) + if "merged-meg-ds_v3_pii" in paths[0]: + d = build_dataset_group_gpt( + dataset_group_name=name, + paths=paths, + weights=weights, + splits=splits, + data_impl=args.data_impl, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_valid_test=s + ) + else: + d = build_dataset_group( + dataset_group_name=name, + paths=paths, + weights=weights, + splits=splits, + data_impl=args.data_impl, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length + 1, + pad_token=tokenizer.pad, + eos_token=tokenizer.eos, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_valid_test=s + ) eval(f"{s}_ds").append(d) else: raise NotImplementedError("No dataloading argument passed") diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 3b8cbb960..57a4cebfc 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -185,7 +185,7 @@ def CrossEntropy(output, labels): else: average_tokens_per_sample = sequence_length expected_number_of_tokens = average_tokens_per_sample * micro_batch_size - elif args.norm_target_loss: + elif args.norm_target_loss and (loss_mask.dim() == 1): expected_num_of_target_seqs = loss_mask.sum() loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs return loss From 477cda6d4757225840f879277067cd12c6796773 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 3 Sep 2022 14:19:21 +0200 Subject: [PATCH 12/12] Set iteration to args by default --- megatron/training.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index bd00bc77e..65694c410 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -183,7 +183,7 @@ def pretrain(train_valid_test_dataset_provider, timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup']) print_rank_0('training ...') - iteration = 0 + iteration = args.iteration if args.do_train and args.train_iters > 0: iteration = train(forward_step_func, model, optimizer, lr_scheduler, @@ -199,7 +199,8 @@ def pretrain(train_valid_test_dataset_provider, iterator, model, 
                                        iteration, False, data_group_name=name)
 
-    if args.save and iteration != 0:
+    # Do not save if the iteration has not changed
+    if args.save and iteration != args.iteration:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
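
For reference, a minimal standalone sketch of the per-target loss normalization these
patches converge on. It is not part of the patch series: it mirrors the final
fast_normalize, and the mask/loss values below are made up for illustration.

import torch

def fast_normalize(loss_mask: torch.Tensor) -> torch.Tensor:
    """Turn [0,0,0,1,1,0,0,1,0,0,1,1,1] into [0,0,0,0.5,0.5,0,0,1,0,0,1/3,1/3,1/3]."""
    _, inverse_indices, counts = torch.unique_consecutive(
        loss_mask, return_inverse=True, return_counts=True
    )
    # Broadcast each consecutive run's length back onto its positions, then divide,
    # so every masked-in span (one packed target) sums to exactly 1.
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    return loss_mask / counts

loss_mask = torch.tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1], dtype=torch.float)
losses = torch.arange(1.0, 14.0)  # dummy per-token losses for one packed sample

norm_mask = fast_normalize(loss_mask)
# norm_mask.sum() equals the number of target spans (3 here), so this is the mean over
# targets of each target's mean token loss: every packed target counts equally,
# regardless of its length.
loss_per_target = torch.sum(losses * norm_mask) / norm_mask.sum()
# Plain token-level average for comparison (what the unnormalized path computes);
# here the long third target dominates it.
loss_per_token = torch.sum(losses * loss_mask) / loss_mask.sum()

Under --norm-target-loss the training loss effectively becomes this per-target mean;
without it, long targets dominate the token-level average when examples are packed.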