tmp #7

Merged
merged 3 commits into from May 19, 2023

3 changes: 2 additions & 1 deletion megatron/arguments.py
@@ -664,7 +664,8 @@ def _add_learning_rate_args(parser):
                        'from checkpoint and ignore input arguments.')
     group.add_argument('--universal-checkpoint', action='store_true',
                        help='Loading a universal format checkpoint.')
-
+    group.add_argument('--reset-progress', action='store_true', default=None,
+                       help='Reset iteration to 0 & do not load args.')
     return parser


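Note that `--reset-progress` only changes behavior at checkpoint-load time. A minimal, self-contained sketch of the intended semantics follows; the parser lines mirror the diff above, while `resolve_start_iteration` is a hypothetical stand-in for the guard patched in megatron/checkpointing.py below.

import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--finetune', action='store_true')
    # The flag added in this PR.
    parser.add_argument('--reset-progress', action='store_true', default=None,
                        help='Reset iteration to 0 & do not load args.')
    return parser

def resolve_start_iteration(args, state_dict, release=False):
    # Start from iteration 0 when finetuning, when loading a release
    # checkpoint, or when --reset-progress is given; otherwise resume
    # from the iteration stored in the checkpoint.
    if args.finetune or release or args.reset_progress:
        return 0
    return state_dict['iteration']

args = build_parser().parse_args(['--reset-progress'])
print(resolve_start_iteration(args, {'iteration': 5000}))  # -> 0
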
6 changes: 3 additions & 3 deletions megatron/checkpointing.py
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
-    if args.finetune or release:
+    if args.finetune or release or args.reset_progress:
         iteration = 0
     else:
         try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
+    if 'args' in state_dict and not args.reset_progress:
         checkpoint_args = state_dict['args']
         if not args.universal_checkpoint:
             check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
     return {
         "padded_vocab_size": args.padded_vocab_size,
         "original_vocab_size": tokenizer.vocab_size,
-    }
+    }
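
The second hunk is subtler than the iteration reset: under `--reset-progress` the checkpoint's stored `args` are never consulted, so the consistency checks and consumed-sample bookkeeping that normally follow are skipped and the run keeps its command-line values. A sketch of the resulting control flow; both functions here are illustrative stand-ins, not the real Megatron helpers.

def check_checkpoint_args(checkpoint_args):
    # Stand-in for Megatron's real check, which asserts that key settings
    # (num_layers, hidden_size, parallelism degrees, ...) match the run.
    pass

def maybe_apply_checkpoint_args(args, state_dict):
    # With --reset-progress the stored training args are ignored entirely,
    # which is what lets a run reuse the weights while restarting progress.
    if 'args' in state_dict and not args.reset_progress:
        checkpoint_args = state_dict['args']
        if not args.universal_checkpoint:
            check_checkpoint_args(checkpoint_args)
        return checkpoint_args
    return None  # caller falls back to the freshly parsed CLI arguments
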
72 changes: 3 additions & 69 deletions megatron/tokenizer/tokenizer.py
@@ -40,9 +40,7 @@ def build_tokenizer(args):
                                   vocab_extra_ids=args.vocab_extra_ids)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
-        tokenizer = _GPT2BPETokenizer(
-            args.vocab_file, args.merge_file, vocab_extra_ids=args.vocab_extra_ids
-        )
+        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
     elif args.tokenizer_type == "PretrainedFromHF":
         assert args.tokenizer_name_or_path is not None
 
@@ -288,36 +286,13 @@ def additional_special_tokens(self, value):
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""
 
-    def __init__(self, vocab_file, merge_file, vocab_extra_ids=0):
+    def __init__(self, vocab_file, merge_file):
         name = 'GPT2 BPE'
         super().__init__(name)
 
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                        special_tokens=[], max_len=None)
-        self.eod_id = self.eos_token_id = self.tokenizer.encoder['<|endoftext|>']
-
-        self.bod_id = self.bos_token_id = self.tokenizer.encoder['[EOS]']
-        self.sep_id = self.tokenizer.encoder['[SEP]']
-        self.mask_id = self.tokenizer.encoder['[MASK]']
-        self.pad_id = self.tokenizer.encoder['[PAD]']
-
-        additional_special_tokens = []
-        self._additional_special_tokens = []
-        additional_special_tokens.extend(
-            ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
-        self.add_additional_special_tokens(additional_special_tokens)
-
-    def add_additional_special_tokens(self, tokens_list):
-        setattr(self, "additional_special_tokens", tokens_list)
-        for value in tokens_list:
-            self.add_token(value)
-
-    def add_token(self, token):
-        if token not in self.vocab:
-            self.inv_vocab[self.vocab_size] = token
-            # self.vocab_size comes from len(vocab)
-            # and it will increase as we add elements
-            self.vocab[token] = self.vocab_size
+        self.eod_id = self.tokenizer.encoder['<|endoftext|>']
 
     @property
     def vocab_size(self):
@@ -341,35 +316,6 @@ def detokenize(self, token_ids):
     def eod(self):
         return self.eod_id
 
-    @property
-    def bod(self):
-        return self.bod_id
-
-    @property
-    def sep(self):
-        return self.sep_id
-
-    @property
-    def mask(self):
-        return self.mask_id
-
-    @property
-    def pad(self):
-        return self.pad_id
-
-    @property
-    def additional_special_tokens(self):
-        """ All the additional special tokens you may want to use (list of strings)."""
-        return self._additional_special_tokens
-
-    @property
-    def additional_special_tokens_ids(self):
-        """ Ids of all the additional special tokens in the vocabulary (list of integers)."""
-        return [self.vocab.get(token) for token in self._additional_special_tokens]
-
-    @additional_special_tokens.setter
-    def additional_special_tokens(self, value):
-        self._additional_special_tokens = value
 
 class _AutoTokenizer(AbstractTokenizer):
     """AutoTokenizer for Hf Pretrained model loading."""
@@ -442,18 +388,6 @@ def eos(self):
         candidate = self.tokenizer.eos_token_id
         return self._check_token_candidate(candidate)
 
-    @property
-    def bos_token_id(self):
-        """Id of the beginning of sentence token in the vocabulary."""
-        candidate = self.tokenizer.bos_token_id
-        return self._check_token_candidate(candidate)
-
-    @property
-    def eos_token_id(self):
-        """Id of the end of sentence token in the vocabulary."""
-        candidate = self.tokenizer.eos_token_id
-        return self._check_token_candidate(candidate)
-
     @property
     def additional_special_tokens_ids(self):
         """ All the additional special tokens you may want to use (list of strings)."""
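
Net effect of this file's changes: `_GPT2BPETokenizer` reverts to the plain upstream GPT-2 BPE wrapper, dropping the `[EOS]`/`[SEP]`/`[MASK]`/`[PAD]` lookups, the `<extra_id_*>` tokens, and the dynamically grown vocabulary. For orientation, a condensed view of the class as it stands after the diff; the `vocab_size`/`tokenize`/`detokenize` bodies are reconstructed from this file's unchanged context and are not part of the PR.

class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""

    def __init__(self, vocab_file, merge_file):
        super().__init__('GPT2 BPE')
        self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                       special_tokens=[], max_len=None)
        # The only special token the stock GPT-2 vocabulary ships with.
        self.eod_id = self.tokenizer.encoder['<|endoftext|>']

    @property
    def vocab_size(self):
        return len(self.tokenizer.encoder)

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id
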
14 changes: 10 additions & 4 deletions tasks/eval_harness/evaluate.py
@@ -398,6 +398,8 @@ def tasks_args(parser):
     group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task')
     group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation')
     group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove bubble when pipeline parallel')
+    group.add_argument('--fewshots', type=int, default=0, help='Num fewshots')
+    group.add_argument('--limit', type=int, default=None, help='Limit samples')
     group.add_argument('--add_denoiser', default = False, action='store_true', help='Whether to add a denoiser to the model')
     return parser

@@ -407,6 +409,10 @@ def main():
     # parse the megatron args. But wait with initalizing megatron.
     # avoid printing the arguments, since they will later be overridden.
     args = _parse_args(tasks_args)
+    if os.path.exists(args.results_path):
+        print("Exists ", args.results_path)
+        exit()
+
     load_path = args.load
     model = load_ds_checkpoint_and_setup_megatron(args)

@@ -431,11 +437,11 @@
         global_results = {"results": {}, "versions": {}}
         timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
         iteration_id = load_path.split("/")[-1].replace("/", "")
-        results_path = args.results_path.replace(".json", f"_lm-eval_{iteration_id}_{timestamp}.json")
+        results_path = args.results_path#.replace(".json", f"_lm-eval_{iteration_id}_{timestamp}_{args.fewshots}shots.json")
         # Backup file in case of interruption during writing
-        results_path_backup = args.results_path.replace(".json", f"_lm-eval_{iteration_id}_{timestamp}_backup.json")
+        results_path_backup = args.results_path.replace(".json", f"_lm-eval_{iteration_id}_{timestamp}_{args.fewshots}shots_backup.json")
         for task_name, task in task_dict.items():
-            results = evaluator.evaluate(adaptor, {task_name: task}, False, 0, None, bootstrap_iters=args.bootstrap_iters)
+            results = evaluator.evaluate(adaptor, {task_name: task}, False, args.fewshots, bootstrap_iters=args.bootstrap_iters, limit=args.limit)
             global_results["results"] = {**global_results["results"], **results["results"]}
             global_results["versions"] = {**global_results["versions"], **results["versions"]}
             if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0:
@@ -445,7 +451,7 @@ def main():
                 with open(results_path_backup, 'w') as outfile:
                     json.dump(global_results, outfile, indent=4)
     else:
-        global_results = evaluator.evaluate(adaptor, task_dict, False, 0, None, bootstrap_iters=args.bootstrap_iters)
+        global_results = evaluator.evaluate(adaptor, task_dict, False, args.fewshots, bootstrap_iters=args.bootstrap_iters, limit=args.limit)
         if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0:
             print(json.dumps(global_results, indent=2))
             with open(args.results_path, 'w') as outfile:
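
Two behavioral notes on the evaluate.py changes: the `--fewshots`/`--limit` values are threaded into every `evaluator.evaluate` call, and the new guard at the top of `main()` makes re-submitted jobs idempotent by exiting as soon as `args.results_path` exists. A self-contained sketch of that guard pattern; the path below is hypothetical, and `sys.exit` stands in for the bare `exit()` used in the diff.

import json
import os
import sys

def run_eval(results_path):
    # Guard added by this PR: skip work that is already done, so
    # re-submitted array jobs neither redo nor clobber finished shards.
    if os.path.exists(results_path):
        print("Exists ", results_path)
        sys.exit(0)

    # ... expensive model loading and evaluation would happen here ...
    global_results = {"results": {}, "versions": {}}

    with open(results_path, 'w') as outfile:
        json.dump(global_results, outfile, indent=4)

run_eval("results/task_0shots.json")  # hypothetical output path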