-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathmodel.py
252 lines (221 loc) · 10.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
"""
Full definition of a GPT Language Model, all of it in this single file.
Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
from dataclasses import dataclass
import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
@dataclass
class GPTConfig:
    """
    Hyperparameters describing a GPT model.

    The architecture is chosen either through the ``model_type`` preset name or
    by spelling out ``n_layer``, ``n_head`` and ``n_embd``. Those three default
    to ``None`` ("not specified"); ``GPT`` fills them in from the preset when
    any of them is left as ``None``.
    """

    # name of a preset architecture, looked up by GPT when sizes are not given
    model_type: str = 'gpt2'
    # model configurations (None means "not specified")
    n_layer: Optional[int] = None
    n_head: Optional[int] = None
    n_embd: Optional[int] = None
    # openai's values for gpt2
    vocab_size: int = 50257
    block_size: int = 1024
    # dropout hyperparameters
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1
@dataclass
class OptimizerConfig:
    """Settings read by ``create_optimizer`` when it builds the optimizer."""

    # passed to the optimizer as its ``lr``
    learning_rate: float = 3e-4
    # applied only to the parameter group that is regularized
    weight_decay: float = 0.1
class MultiheadAttentionLayer(nn.Module):
    """
    A multi-head masked self-attention layer with a projection at the end.

    Wraps ``torch.nn.MultiheadAttention`` (``batch_first=True``) with a causal
    mask so that a position can only attend to itself and earlier positions,
    then applies an extra linear projection followed by residual dropout.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``block_size``,
            ``attn_pdrop`` and ``resid_pdrop``.
        device: device on which the projection and attention weights are created.
        dtype: dtype of the projection and attention weights.
    """

    def __init__(self, config, device="cpu", dtype=torch.float32):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
        # Lower-triangular matrix of ones: entry (i, j) is 1 when position i may
        # look at position j. The name, dtype and (1, 1, T, T) layout are kept
        # unchanged so that previously saved state dicts still load.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=config.n_embd,
            num_heads=config.n_head,
            dropout=config.attn_pdrop,
            batch_first=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """
        Apply causal self-attention to ``x`` of shape (batch, seq, n_embd).

        Returns a tensor of the same shape.
        """
        _, seq_size, _ = x.size()
        # nn.MultiheadAttention *adds* a float attn_mask to the attention scores,
        # so handing it the 0/1 matrix directly masks nothing and lets every
        # position see the future. A boolean mask is interpreted as
        # "True = not allowed to attend", hence the `== 0`.
        blocked = self.mask[0, 0, :seq_size, :seq_size] == 0
        y = self.attn(x, x, x, attn_mask=blocked)[0]
        y = self.resid_drop(self.c_proj(y))
        return y
class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # one layer norm in front of each of the two sub-layers
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = MultiheadAttentionLayer(config)
        # feed-forward part: expand to 4x the width, GELU, project back, dropout
        feed_forward = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.resid_pdrop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        # both sub-layers are added back onto their own input
        attended = self.attn(self.ln1(x))
        x = x + attended
        return x + self.mlp(self.ln2(x))
class EmbeddingStem(nn.Module):
    """
    Input stem: token embedding plus learnable position embedding, with dropout.

    Args:
        config: object exposing ``vocab_size``, ``n_embd``, ``block_size`` and
            ``embd_pdrop``.
        device: device on which the embedding parameters are created.
        dtype: dtype of the embedding parameters.
    """

    def __init__(self, config: "GPTConfig", device="cpu", dtype=torch.float32):
        super().__init__()
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd, device=device, dtype=dtype)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd, device=device, dtype=dtype))
        self.drop = nn.Dropout(config.embd_pdrop)
        self.block_size = config.block_size

    def reset_parameters(self):
        """Restore both embeddings to their freshly constructed state."""
        self.tok_emb.reset_parameters()
        # pos_emb is created as zeros in __init__; re-zero it here as well so a
        # reset does not leave stale (or uninitialized) position values behind
        with torch.no_grad():
            self.pos_emb.zero_()

    def forward(self, idx):
        """
        Embed token indices ``idx`` of shape (batch, seq).

        Returns the dropout-regularized sum of token and position embeddings,
        shape (batch, seq, n_embd). ``seq`` must not exceed ``block_size``.
        """
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) embedding vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) position vector
        return self.drop(token_embeddings + position_embeddings)
class GPT(nn.Module):
    """ GPT Language Model """

    def __init__(self, config: "GPTConfig"):
        """
        Build the model described by ``config``.

        ``config`` is updated in place with the preset sizes when it names a
        ``model_type`` without giving all of ``n_layer``/``n_head``/``n_embd``.
        """
        super().__init__()
        self.block_size = config.block_size
        config = self._set_model_config(config)
        # input embedding stem
        self.emb_stem = EmbeddingStem(config)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        # NOTE(review): only parameters named '*c_proj.weight' are rescaled; the
        # MLP output Linear lives inside an nn.Sequential in Block and is not
        # matched by this suffix -- confirm whether that is intended.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                p.data.normal_(mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
        # report number of parameters (only the transformer blocks are counted:
        # the embedding stem, final layer norm and decoder head are left out)
        n_params = sum(p.numel() for p in self.blocks.parameters())
        print("number of parameters: %.2fM" % (n_params/1e6,))

    def _set_model_config(self, config):
        """
        Fill ``n_layer``/``n_head``/``n_embd`` from the ``model_type`` preset.

        The preset is applied (in place) only when ``model_type`` is set and at
        least one of the three sizes is ``None``; otherwise ``config`` is
        returned untouched.

        Raises:
            KeyError: if a preset is needed but ``model_type`` is not a known name.
        """
        type_given = config.model_type is not None
        params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
        # assert type_given ^ params_given # exactly one of these (XOR)
        if type_given and not params_given:
            presets = {
                # names follow the huggingface naming conventions
                # GPT-1
                'openai-gpt':   dict(n_layer=12, n_head=12, n_embd=768),  # 117M params
                # GPT-2 configs
                'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
                'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
                'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
                'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
                # Gophers
                'gopher-44m':   dict(n_layer=8, n_head=16, n_embd=512),
                # (there are a number more...)
                # I made these tiny models up
                'gpt-mini':     dict(n_layer=6, n_head=6, n_embd=192),
                'gpt-micro':    dict(n_layer=4, n_head=4, n_embd=128),
                'gpt-nano':     dict(n_layer=3, n_head=3, n_embd=48),
            }
            # translate from model_type to detailed configuration
            try:
                preset = presets[config.model_type]
            except KeyError:
                # same exception type as a plain lookup, but say what went wrong
                raise KeyError(
                    f"unknown model_type {config.model_type!r}; expected one of {sorted(presets)}"
                ) from None
            config.__dict__.update(preset)
        return config

    def _init_weights(self, module):
        """Per-module initializer used with ``nn.Module.apply``."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, targets=None):
        """
        Run the model on token indices ``idx``.

        Returns ``(logits, loss)``. ``loss`` is the cross entropy between the
        flattened logits and ``targets`` (entries equal to -1 are ignored), or
        ``None`` when no targets are given.
        """
        x = self.emb_stem(idx)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        ``top_k`` larger than the vocabulary size is treated as "keep everything".
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                # torch.topk raises when k exceeds the last dimension, so clamp it
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
def create_optimizer(model: torch.nn.Module, opt_config: "OptimizerConfig"):
    """
    This long function is unfortunately doing something very simple and is being very defensive:
    We are separating out all parameters of the model into two buckets: those that will experience
    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
    We are then returning the PyTorch optimizer object.

    Args:
        model: module whose parameters are to be optimized.
        opt_config: object exposing ``learning_rate`` and ``weight_decay``.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups: decayed first, then non-decayed.

    Raises:
        AssertionError: if a parameter lands in both buckets or in neither.
    """
    # canonical name -> tensor; a tensor reachable under several names appears once
    param_dict = dict(model.named_parameters())

    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        # recurse=False: look only at the parameters owned directly by m, so every
        # tensor is examined once, together with the module it actually belongs to
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if fpn not in param_dict:
                # alias of a tied parameter; it is classified under its canonical name
                continue
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('in_proj_weight'):
                # MHA projection layer
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            elif pn.endswith('pos_emb'):
                # positional embedding shouldn't be decayed
                no_decay.add(fpn)

    # validate that we considered every parameter
    inter_params = decay & no_decay
    union_params = decay | no_decay
    missing_params = param_dict.keys() - union_params
    assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
    assert len(missing_params) == 0, \
        f"parameters {missing_params} were not separated into either decay/no_decay set!"

    # create the pytorch optimizer object (names are sorted for a stable parameter order)
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": opt_config.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=opt_config.learning_rate, betas=(0.9, 0.95))
    return optimizer