diff --git a/finetune_t0.py b/finetune_t0.py
index e4af9b91f..88e90a586 100644
--- a/finetune_t0.py
+++ b/finetune_t0.py
@@ -2,7 +2,9 @@
 
 import torch
 
+from pretrain_gpt import get_batch_pipe as get_batch_pipe_gpt
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron.data.gpt_dataset import build_dataset_group as build_dataset_group_gpt
 from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.enums import PositionEmbeddingType, AttnMaskType
 from megatron.model import GPTModelPipe
@@ -48,6 +50,14 @@ def model_provider(pre_process=True, post_process=True):
 
     return model
 
+def fast_normalize(loss_mask: torch.Tensor):
+    """
+    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
+    """
+    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
+    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
+    return loss_mask / counts
+
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -57,6 +67,9 @@ def get_batch_pipe(data):
     decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
     decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
     """
+    if 'text' in data:
+        return get_batch_pipe_gpt(data)
+
     args = get_args()
     tokenizer = get_tokenizer()
 
@@ -95,6 +108,10 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )
 
+    if args.norm_target_loss:
+        loss_mask = loss_mask.view(-1)
+        loss_mask = fast_normalize(loss_mask)
+
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
 
@@ -142,20 +159,34 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                               eval(f"args.{s}_weighted_split_splits"),
                               eval(f"args.{s}_weighted_split_names"))
             for paths, weights, splits, name in data_groups:
-                d = build_dataset_group(
-                    dataset_group_name=name,
-                    paths=paths,
-                    weights=weights,
-                    splits=splits,
-                    data_impl=args.data_impl,
-                    train_valid_test_num_samples=train_val_test_num_samples,
-                    seq_length=args.seq_length + 1,
-                    pad_token=tokenizer.pad,
-                    eos_token=tokenizer.eos,
-                    seed=args.seed,
-                    skip_warmup=(not args.mmap_warmup),
-                    train_valid_test=s
-                )
+                if "merged-meg-ds_v3_pii" in paths[0]:
+                    d = build_dataset_group_gpt(
+                        dataset_group_name=name,
+                        paths=paths,
+                        weights=weights,
+                        splits=splits,
+                        data_impl=args.data_impl,
+                        train_valid_test_num_samples=train_val_test_num_samples,
+                        seq_length=args.seq_length,
+                        seed=args.seed,
+                        skip_warmup=(not args.mmap_warmup),
+                        train_valid_test=s
+                    )
+                else:
+                    d = build_dataset_group(
+                        dataset_group_name=name,
+                        paths=paths,
+                        weights=weights,
+                        splits=splits,
+                        data_impl=args.data_impl,
+                        train_valid_test_num_samples=train_val_test_num_samples,
+                        seq_length=args.seq_length + 1,
+                        pad_token=tokenizer.pad,
+                        eos_token=tokenizer.eos,
+                        seed=args.seed,
+                        skip_warmup=(not args.mmap_warmup),
+                        train_valid_test=s
+                    )
                 eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 6622df924..a51ac6a33 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -930,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--norm-target-loss', action='store_true',
+                       help='Normalize the loss per target. Used for multi-task finetuning with packing.')
     group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                        help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 6896df7f4..57a4cebfc 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -15,7 +15,6 @@
 
 """GPT-2 model."""
 
-from functools import partial
 import torch
 
 from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
             else:
                 average_tokens_per_sample = sequence_length
             expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
+        elif args.norm_target_loss and (loss_mask.dim() == 1):
+            expected_num_of_target_seqs = loss_mask.sum()
+            loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
+            return loss
         else:
             expected_number_of_tokens = loss_mask.sum()
 
diff --git a/megatron/training.py b/megatron/training.py
index bd00bc77e..65694c410 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -183,7 +183,7 @@ def pretrain(train_valid_test_dataset_provider,
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')
 
-    iteration = 0
+    iteration = args.iteration
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
@@ -199,7 +199,8 @@ def pretrain(train_valid_test_dataset_provider,
                                        iterator, model, iteration,
                                        False, data_group_name=name)
 
-    if args.save and iteration != 0:
+    # Do not save if the iteration has not changed
+    if args.save and iteration != args.iteration:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
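
Illustration (not part of the patch): a minimal, standalone sketch of what the new fast_normalize helper does to a packed loss mask, with the helper copied from the finetune_t0.py hunk above. Each contiguous run of 1s (one target sequence) is rescaled so the run sums to 1, which is why the CrossEntropy branch above divides by the number of target sequences rather than the number of target tokens.

    import torch

    def fast_normalize(loss_mask: torch.Tensor):
        # Divide every element by the length of the consecutive run it belongs to.
        _, inverse_indices, counts = torch.unique_consecutive(
            loss_mask, return_inverse=True, return_counts=True)
        counts = torch.gather(dim=0, index=inverse_indices, input=counts)
        return loss_mask / counts

    mask = torch.tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1], dtype=torch.float)
    print(fast_normalize(mask))
    # tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 1.0000,
    #         0.0000, 0.0000, 0.3333, 0.3333, 0.3333])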