main.py
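
"""Distributed training entry point: builds a character-level GPT, its
optimizer, and train/test splits from a Hydra config, then runs the Trainer
inside a torch.distributed process group and tears it down on exit."""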
import os

import torch
from torch.utils.data import random_split
from torch.distributed import init_process_group, destroy_process_group
import hydra
from omegaconf import DictConfig

from model import GPT, GPTConfig, OptimizerConfig, create_optimizer
from trainer import Trainer, TrainerConfig
from char_dataset import CharDataset, DataConfig


def ddp_setup():
    # Initialize the default process group for distributed training and pin
    # this process to its GPU. LOCAL_RANK is set by the distributed launcher.
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
    # Build the dataset and split it into train/test subsets.
    dataset = CharDataset(data_cfg)
    train_len = int(len(dataset) * data_cfg.train_split)
    train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])

    # The vocabulary and context length are dictated by the dataset.
    gpt_cfg.vocab_size = dataset.vocab_size
    gpt_cfg.block_size = dataset.block_size
    model = GPT(gpt_cfg)
    optimizer = create_optimizer(model, opt_cfg)
    return model, optimizer, train_set, test_set


@hydra.main(version_base=None, config_path=".", config_name="gpt2_train_cfg")
def main(cfg: DictConfig):
    ddp_setup()

    # Materialize typed config objects from the Hydra/OmegaConf config tree.
    gpt_cfg = GPTConfig(**cfg['gpt_config'])
    opt_cfg = OptimizerConfig(**cfg['optimizer_config'])
    data_cfg = DataConfig(**cfg['data_config'])
    trainer_cfg = TrainerConfig(**cfg['trainer_config'])

    model, optimizer, train_data, test_data = get_train_objs(gpt_cfg, opt_cfg, data_cfg)
    trainer = Trainer(trainer_cfg, model, optimizer, train_data, test_data)
    trainer.train()

    destroy_process_group()


if __name__ == "__main__":
    main()
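
# A minimal launch sketch: this script expects a distributed launcher to set
# LOCAL_RANK (and the other process-group environment variables) before
# main() runs, as torchrun does. The single-node process count below is a
# placeholder, not part of this repo:
#
#   torchrun --standalone --nproc_per_node=4 main.py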