-
Notifications
You must be signed in to change notification settings - Fork 890
/
Copy pathcheckpoint.py
410 lines (307 loc) · 13 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import functools
import io
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import multiprocessing
import ray
from smart_open import open
from mesh_transformer.util import head_print
pieces = 16  # number of .npz files the flattened leaves of one shard are split across
def fix_dtype(pytree):
    """Return ``pytree`` with every leaf converted to a JAX array.

    Leaves whose dtype is the 2-byte void type ``'V2'`` are reinterpreted as
    bfloat16 first (presumably this is how bfloat16 data comes back from
    ``np.load`` -- the readers in this file apply the same reinterpretation).

    The reinterpretation is done on a view, so the arrays passed in by the
    caller keep their original dtype.
    """
    def fix(x):
        if x.dtype == np.dtype('V2'):
            # A view reinterprets the same bytes without touching ``x`` itself;
            # assigning to ``x.dtype`` would mutate the caller's array.
            x = x.view(np.dtype(jnp.bfloat16))
        return jnp.asarray(x)

    # ``jax.tree_util.tree_map`` is the stable spelling; the top-level
    # ``jax.tree_map`` alias is no longer available in current JAX releases.
    return jax.tree_util.tree_map(fix, pytree)
@functools.partial(jax.jit, backend="cpu")
def index_weights(weights, idx):
    """Index every leaf of ``weights`` with ``idx`` along its leading axis.

    The function is jit-compiled for the CPU backend and the indexed pytree is
    placed on the first CPU device.

    Args:
        weights: Pytree of arrays that all support ``leaf[idx]``.
        idx: Index applied to each leaf.

    Returns:
        A pytree with the same structure as ``weights`` holding ``leaf[idx]``.
    """
    cpu_device = jax.devices("cpu")[0]
    # ``jax.tree_util.tree_map`` replaces the removed top-level ``jax.tree_map``.
    indexed = jax.tree_util.tree_map(lambda leaf: leaf[idx], weights)
    return jax.device_put(indexed, cpu_device)
def write(x, ckpt_dir):
    """Save one chunk of arrays to ``{ckpt_dir}{idx}.npz``, retrying on failure.

    Args:
        x: ``(idx, arrays)`` pair. ``idx`` names the output file and
            ``arrays`` is unpacked positionally into ``np.savez``.
        ckpt_dir: Path prefix the file name is appended to verbatim, so it
            must already end with a path separator.

    Raises:
        Exception: ``"save failed"`` after three unsuccessful attempts,
            chained to the last underlying error.
    """
    idx, arrays = x
    file_path = ckpt_dir + f"{idx}.npz"

    last_error = None
    for _ in range(3):
        try:
            with open(file_path, "wb") as f:
                np.savez(f, *arrays)
            return
        except Exception as err:
            # Catch ``Exception`` rather than everything so Ctrl-C and
            # interpreter shutdown are not turned into retries.
            last_error = err
            print("save failed, trying again")

    print("save failed 3 times, exiting")
    raise Exception("save failed") from last_error
def split(a, n):
    """Lazily cut ``a`` into ``n`` consecutive slices of near-equal length.

    Slice lengths differ by at most one; the first ``len(a) % n`` slices get
    the extra element. Returns a generator of slices of ``a``.
    """
    base, extra = divmod(len(a), n)

    def offset(j):
        # Start of slice j: j base-sized steps plus one for each earlier long slice.
        return j * base + min(j, extra)

    return (a[offset(j):offset(j + 1)] for j in range(n))
def write_ckpt(pytree, dir, shard):
    """Write entry ``shard`` of every leaf of ``pytree`` under ``{dir}shard_{shard}/``.

    The pytree is flattened, each leaf is indexed with ``shard`` via
    ``index_weights``, the resulting list is cut into ``pieces`` chunks and the
    chunks are saved concurrently as ``0.npz`` ... ``{pieces - 1}.npz``.

    Args:
        pytree: Pytree whose leaves are indexable by ``shard``.
        dir: Path prefix; ``shard_{shard}/`` is appended verbatim, so it must
            already end with a path separator.
        shard: Index selected from every leaf and used in the directory name.
    """
    # Imported here because a bare ``import multiprocessing`` does not
    # guarantee that the ``pool`` submodule has been loaded.
    from multiprocessing.pool import ThreadPool

    # ``jax.tree_util.tree_flatten`` replaces the removed ``jax.tree_flatten``.
    flattened, _ = jax.tree_util.tree_flatten(pytree)
    cpu_flattened = index_weights(flattened, shard)
    cpu_flattened_chunked = split(cpu_flattened, pieces)

    write_fn = functools.partial(write, ckpt_dir=f"{dir}shard_{shard}/")
    with ThreadPool(pieces) as p:
        # Drain the iterator so every write has completed (and any failure has
        # been re-raised here) before the pool is torn down.
        list(p.imap_unordered(write_fn, enumerate(cpu_flattened_chunked)))
def read_shard(ckpt_dir):
    """Load all arrays of one shard directory into a flat list.

    Files ``0.npz`` ... ``{pieces - 1}.npz`` are read in order, and within
    each file the arrays are appended in the archive's own iteration order.

    Args:
        ckpt_dir: Path prefix the file names are appended to verbatim, so it
            must already end with a path separator.

    Returns:
        List of NumPy arrays.
    """
    out = []
    # Use the same constant as the writer instead of a separate literal so the
    # two cannot drift apart.
    for idx in range(pieces):
        file_path = ckpt_dir + f"{idx}.npz"
        with open(file_path, "rb") as f:
            buf = f.read()
        # Parse from memory; the archive is closed as soon as its arrays have
        # been materialised.
        with np.load(io.BytesIO(buf)) as deserialized:
            out.extend(deserialized[name] for name in deserialized)
    return out
def reshard(x, old_shape):
    """Collapse an array stacked over shards (axis 0) towards ``old_shape``.

    Behaviour depends on the rank of ``x``:

    * rank 1: keep only the first element (as a length-1 slice).
    * rank 2: if every row after the first equals the last row, keep the first
      row -- unchanged when those rows are all 0 or all 1, otherwise scaled by
      ``x.shape[0] / old_shape[0]``; if the rows differ, reshape to
      ``old_shape``. The first 16 columns of the result are printed.
    * rank 3: if ``x.shape[0] * x.shape[2] == old_shape[2]`` swap the first two
      axes and reshape; else if ``x.shape[0] * x.shape[1] == old_shape[1]``
      reshape directly.

    Raises:
        Exception: for any other rank, or a rank-3 input matching neither rule.
    """
    rank = len(x.shape)

    if rank == 1:
        return x[0:1]

    if rank == 2:
        tail = x[1:]
        if not (tail == x[-1]).all():
            # Rows differ between shards.
            out = x.reshape(old_shape)
        elif (tail == 0).all() or (tail == 1).all():
            out = x[0:1]
        else:
            out = x[0:1] * x.shape[0] / old_shape[0]
        print(out[:, :16])
        return out

    if rank == 3:
        n_shards, mid, last = x.shape
        if n_shards * last == old_shape[2]:
            return jnp.transpose(x, (1, 0, 2)).reshape(old_shape)
        if n_shards * mid == old_shape[1]:
            return x.reshape(old_shape)
        raise Exception(f"unimplemented, {x.shape}, {old_shape}")

    raise Exception(f"unimplemented, {x}")
def read_ckpt(pytree, dir, shards_in, shards_out=None, load_opt=True):
    """Load a sharded checkpoint into the structure of ``pytree``.

    All ``shards_in`` shard directories ``{dir}shard_{i}/`` are read
    concurrently, corresponding arrays are stacked along a new leading axis,
    and -- when ``shards_out`` differs from ``shards_in`` -- passed through
    ``reshard``. If any stacked array does not match the shape of its template
    leaf, the load is retried once without ``opt_state`` and the template's
    original ``opt_state`` is put back into the result.

    Args:
        pytree: Template mapping; must contain an ``"opt_state"`` entry. On
            the fallback path that entry is deleted from this object in place.
        dir: Path prefix ending with a separator.
        shards_in: Number of shards stored in the checkpoint.
        shards_out: Number of shards wanted; defaults to ``shards_in``.
        load_opt: When false, the template's ``opt_state`` replaces the loaded one.

    Returns:
        Pytree with the template's structure holding the loaded arrays.

    Raises:
        AssertionError: if shapes still mismatch after dropping ``opt_state``.
    """
    # Imported here because a bare ``import multiprocessing`` does not
    # guarantee that the ``pool`` submodule has been loaded.
    from multiprocessing.pool import ThreadPool

    if shards_out is None:
        shards_out = shards_in

    old_flattened, structure = jax.tree_util.tree_flatten(pytree)

    original_opt_state = pytree["opt_state"]

    # TODO: figure out how to use a process pool here for more speed
    with ThreadPool(shards_in) as p:
        start = time.time()
        shards = list(p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)]))
        print(f"read from disk/gcs in {time.time() - start:.06}s")

    def _unshard(shards, old_flattened):
        unsharded = []

        for old, *all_shards in zip(old_flattened, *shards):
            x = np.stack(all_shards)
            # Arrays can come back as 2-byte void ('V2'); reinterpret the
            # bytes as bfloat16 (presumably the dtype they were saved with).
            if x.dtype == np.dtype('V2'):
                x = x.view(np.dtype(jnp.bfloat16))

            if shards_out != shards_in:
                x = reshard(x, old.shape)
            unsharded.append(x)

            # Raised explicitly (not via ``assert``) because the fallback
            # below depends on it and must keep working under ``python -O``.
            if x.shape != old.shape:
                raise AssertionError(f"Incompatible checkpoints {x.shape} vs {old.shape}")
        return unsharded

    try:
        unsharded = _unshard(shards, old_flattened)
    except AssertionError:
        load_opt = False  # no opt to load in ckpt
        del pytree['opt_state']
        old_flattened, structure = jax.tree_util.tree_flatten(pytree)
        unsharded = _unshard(shards, old_flattened)

    loaded_pytree = jax.tree_util.tree_unflatten(structure, unsharded)

    if not load_opt:
        loaded_pytree['opt_state'] = original_opt_state
    return loaded_pytree
def read_ckpt_lowmem(pytree, dir, shards_in, shards_out=None, load_opt=True):
    """Load a sharded checkpoint one array at a time.

    For each of the ``pieces`` files and each array key in it (keys are taken
    from ``shard_0``), the array is read from every shard with ``np.load``,
    stacked, and placed on a device chosen round-robin from ``jax.devices()``.
    When ``shards_out`` differs from ``shards_in`` the stacked array is passed
    through ``reshard``. If a shape does not match the template leaf, the load
    is retried once without ``opt_state`` and the template's original
    ``opt_state`` is put back into the result.

    Args:
        pytree: Template mapping; must contain an ``"opt_state"`` entry. On
            the fallback path that entry is deleted from this object in place.
        dir: Path prefix ending with a separator; passed to ``np.load`` directly.
        shards_in: Number of shards stored in the checkpoint.
        shards_out: Number of shards wanted; defaults to ``shards_in``.
        load_opt: When false, the template's ``opt_state`` replaces the loaded one.

    Returns:
        Pytree with the template's structure holding the loaded arrays.

    Raises:
        AssertionError: if shapes still mismatch after dropping ``opt_state``.
    """
    if shards_out is None:
        shards_out = shards_in

    old_flattened, structure = jax.tree_util.tree_flatten(pytree)

    original_opt_state = pytree["opt_state"]

    def _unshard():
        # Reads ``old_flattened`` from the enclosing scope, so the retry below
        # sees the template re-flattened without ``opt_state``.
        start = time.time()
        unsharded = []
        devices = jax.devices()
        device_count = len(devices)
        # Running count of arrays loaded so far; used both as the index into
        # ``old_flattened`` and for round-robin device selection.
        device_index = 0

        for file_index in range(pieces):
            with np.load(f"{dir}shard_0/{file_index}.npz") as first_shard:
                array_keys = list(first_shard.keys())
            for key in array_keys:
                unstacked = []
                for shard_index in range(shards_in):
                    # Close each archive as soon as its array has been read
                    # so file handles do not accumulate.
                    with np.load(f"{dir}shard_{shard_index}/{file_index}.npz") as npz:
                        array = npz[key]
                    if array.dtype == 'V2':
                        array = array.view(np.dtype(jnp.bfloat16))
                    unstacked.append(array)

                x = jax.device_put(jnp.stack(unstacked), device=devices[device_index % device_count])

                if shards_out != shards_in:
                    x = reshard(x, old_flattened[device_index].shape)
                unsharded.append(x)

                # Raised explicitly (not via ``assert``) because the fallback
                # below depends on it and must keep working under ``python -O``.
                if x.shape != old_flattened[device_index].shape:
                    raise AssertionError(
                        f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
                    )
                device_index += 1

        print(f"read from disk/gcs in {time.time() - start:.06}s")
        return unsharded

    try:
        unsharded = _unshard()
    except AssertionError:
        load_opt = False  # no opt to load in ckpt
        del pytree['opt_state']
        old_flattened, structure = jax.tree_util.tree_flatten(pytree)
        unsharded = _unshard()

    loaded_pytree = jax.tree_util.tree_unflatten(structure, unsharded)

    if not load_opt:
        loaded_pytree['opt_state'] = original_opt_state
    return loaded_pytree
def parallel_write(arrays, fname):
    """Store ``arrays`` positionally in a single ``.npz`` archive at ``fname``."""
    # TODO: make this actually parallel
    with open(fname, "wb") as sink:
        np.savez(sink, *arrays)
def parallel_read(old, fname, validate=True):
    """Load an ``.npz`` archive into the pytree structure of ``old``.

    Args:
        old: Template pytree; supplies the structure and, when ``validate`` is
            true, the expected shape of every leaf.
        fname: Archive path. Paths containing ``"gs://"`` are read fully into
            memory through ``open`` first; other paths go straight to ``np.load``.
        validate: Check that every loaded array has the template leaf's shape.

    Returns:
        Pytree with the structure of ``old`` whose leaves are the loaded
        arrays after ``fix_dtype``.

    Raises:
        AssertionError: if the number of arrays differs from the number of
            template leaves, or (with ``validate``) a shape differs.
    """
    old_vals, treedef = jax.tree_util.tree_flatten(old)

    if "gs://" in fname:
        # TODO: make this actually parallel
        with open(fname, "rb") as f:
            buf = f.read()
        archive = np.load(io.BytesIO(buf))
    else:
        archive = np.load(fname, mmap_mode='r')

    # Materialise every array, then close the archive so the file handle is
    # not left open.
    with archive as loaded:
        new_vals = [loaded[name] for name in loaded]

    # Explicit raises keep these checks active under ``python -O``.
    if len(new_vals) != len(old_vals):
        raise AssertionError("Incompatible checkpoint")

    if validate:
        for new, template in zip(new_vals, old_vals):
            if new.shape != template.shape:
                raise AssertionError("Incompatible checkpoint")

    return jax.tree_util.tree_unflatten(treedef, fix_dtype(new_vals))
def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
    """Map an identifier of every leaf in ``pytree`` to a slash-separated path.

    Objects with a truthy ``items`` attribute are walked as mappings and extend
    the path with ``/{key}``; other objects with ``__getitem__`` are iterated
    and their elements share the current path; anything else is recorded as a
    leaf at the current path.

    Args:
        pytree: Nested container to walk.
        is_leaf: Predicate called on each child; true children are recorded
            instead of being descended into.
        path: Path prefix for this subtree.
        to_id: Function producing the dictionary key for a leaf (default ``id``).

    Returns:
        Dict mapping ``to_id(leaf)`` to the leaf's path.
    """
    id_to_name = {}
    if getattr(pytree, "items", None):
        for k, v in pytree.items():
            k_path = f"{path}/{k}"
            if is_leaf(v):
                id_to_name[to_id(v)] = k_path
            else:
                # Forward ``to_id`` so nested leaves use the same key function
                # as top-level ones.
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path, to_id=to_id)
                )
    elif getattr(pytree, "__getitem__", None):
        for v in pytree:
            if is_leaf(v):
                id_to_name[to_id(v)] = path
            else:
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=path, to_id=to_id)
                )
    else:
        id_to_name[to_id(pytree)] = path
    return id_to_name
def tree_leaves_with_names(pytree, to_id=id):
    """Map an identifier of every JAX leaf in ``pytree`` to its path.

    A child counts as a leaf when it is not a ``list`` and its identifier is
    among the identifiers of ``jax.tree_util.tree_leaves(pytree)``.

    Args:
        pytree: Pytree to name.
        to_id: Function producing the key for a leaf (default ``id``); its
            results must be hashable.

    Returns:
        Dict mapping ``to_id(leaf)`` to the leaf's slash-separated path.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    # Built once: rebuilding the collection inside the predicate would make
    # the walk quadratic in the number of leaves.
    leaf_ids = {to_id(leaf) for leaf in leaves}

    def is_leaf(node):
        return not isinstance(node, list) and to_id(node) in leaf_ids

    return tree_flatten_with_names(pytree, is_leaf, to_id=to_id)
def write_ckpt_v2(model_state, dir):
    """Write this host's part of a v2 checkpoint under ``dir``.

    Host 0 additionally writes ``{dir}/meta.json`` containing the host count,
    the step, and the path names of the parameter and optimizer-state leaves
    in flattening order. Every host writes its flattened ``params`` and
    ``opt_state`` leaves to ``{dir}/params/shard_{host}.npz`` and
    ``{dir}/opt_state/shard_{host}.npz``.

    Args:
        model_state: Mapping with ``"params"``, ``"opt_state"`` and ``"step"``.
        dir: Checkpoint directory without a trailing slash.
    """
    # ``jax.host_id`` / ``jax.host_count`` are deprecated names for
    # ``process_index`` / ``process_count``; use the new names when present
    # and fall back for JAX versions that predate them.
    process_index = getattr(jax, "process_index", None) or jax.host_id
    process_count = getattr(jax, "process_count", None) or jax.host_count
    host = process_index()

    # Flatten once; the same leaf objects are used for naming and for writing.
    param_leaves = jax.tree_util.tree_leaves(model_state["params"])
    opt_leaves = jax.tree_util.tree_leaves(model_state["opt_state"])

    start = time.time()
    if host == 0:
        param_map = tree_leaves_with_names(model_state["params"])
        opt_map = tree_leaves_with_names(model_state["opt_state"])

        meta = {
            "total_hosts": process_count(),
            "step": int(model_state["step"]),
            "param_order": [param_map[id(leaf)] for leaf in param_leaves],
            "opt_order": [opt_map[id(leaf)] for leaf in opt_leaves],
        }

        print("step:", model_state["step"])
        with open(dir + "/meta.json", "w") as f:
            json.dump(meta, f)

        print(f"meta written in {time.time() - start:.06}s")

    start = time.time()
    parallel_write(param_leaves, dir + f"/params/shard_{host}.npz")
    head_print(f"params written in {time.time() - start:.06}s")

    start = time.time()
    parallel_write(opt_leaves, dir + f"/opt_state/shard_{host}.npz")
    head_print(f"opt_state written in {time.time() - start:.06}s")
def read_sharded_v2(state, dir, checkpoint_hosts, state_shard):
    """Restore one sub-tree of a v2 checkpoint, merging several saved shards per host.

    Each host reads ``checkpoint_hosts // jax.host_count()`` consecutive files
    ``{dir}/shard_{i}.npz`` through ray remote tasks and merges them leaf by
    leaf with ``reshard_v2``.

    Args:
        state: Template pytree giving the expected structure and leaf shapes.
        dir: Directory (no trailing slash) holding the ``shard_{i}.npz`` files.
        checkpoint_hosts: Number of hosts that wrote the checkpoint; must be a
            positive multiple of the current host count.
        state_shard: Pytree matching ``state`` whose leaves are per-dimension
            sharding descriptions (entries are ``None`` or contain ``"mp"``;
            presumably partition specs -- confirm against the caller).

    Returns:
        Pytree with the structure of ``state`` holding the merged arrays.
    """
    files_per_host = checkpoint_hosts // jax.host_count()

    assert files_per_host >= 1, "can't restore model to larger pod than was trained on (yet)"
    assert jax.host_count() * files_per_host == checkpoint_hosts, "weird host count"

    if files_per_host == 1:
        head_print("using fast path of checkpoint restore (save shards == read shards)")
        # NOTE(review): the result of this call is discarded and execution
        # continues into the general path below, so the shard is read again.
        # Looks like a missing ``return`` -- confirm the intended device
        # placement before changing it.
        parallel_read(state, dir + f"/shard_{jax.host_id()}.npz")

    @ray.remote
    def read_remote(old, fname):
        # Shape validation is skipped because ``old`` is the scalar skeleton below.
        return parallel_read(old, fname, validate=False)

    start_idx = files_per_host * jax.host_id()
    skeleton = jax.tree_map(lambda x: jnp.zeros_like(x, shape=()), state)  # a full pytree just to carry dtypes

    refs = [
        read_remote.remote(skeleton, f"{dir}/shard_{i}.npz")
        for i in range(start_idx, start_idx + files_per_host)
    ]

    values = ray.get(refs)

    def all_array_equal(iterator):
        # True when every array equals the first one; an empty input counts as equal.
        try:
            iterator = iter(iterator)
            first = next(iterator)
            return all(jnp.array_equal(first, rest) for rest in iterator)
        except StopIteration:
            return True

    def reshard_v2(old, shard_strategy, *new_values):
        # Merge the per-file values of one leaf according to its sharding description.
        rep_dim_count = shard_strategy.count(None)
        total_dim_count = len(shard_strategy)

        # head_print("old.shape", old.shape)
        # head_print("shard_strategy", shard_strategy)

        assert len(old.shape) == total_dim_count

        if rep_dim_count == total_dim_count:
            # fully replicated
            assert all_array_equal(new_values)
            return fix_dtype(new_values[0])

        shard_dim = [idx for idx, dim in enumerate(shard_strategy) if dim is not None and "mp" in dim]

        # only support sharding in 1d for now
        assert len(shard_dim) == 1
        shard_dim = shard_dim[0]

        # Join the pieces along the sharded axis and check against the template shape.
        ret_val = jnp.concatenate(fix_dtype(new_values), axis=shard_dim)
        assert old.shape == ret_val.shape

        return jax.device_put(ret_val, jax.devices("cpu")[0])

    # head_print("state", jax.tree_structure(state))
    # head_print("state_shard", jax.tree_structure(state_shard))
    # head_print("values", jax.tree_structure(values[0]))

    # NOTE(review): jax.tree_map / jax.tree_multimap / jax.host_id / jax.host_count
    # are legacy top-level names -- confirm they exist in the pinned JAX version.
    return jax.tree_multimap(reshard_v2, *([state, state_shard] + values))
def load_ckpt_v2(model_state, dir, state_shard, load_opt):
    """Load a v2 checkpoint from ``dir``.

    Reads ``meta.json`` for the step and the number of hosts that wrote the
    checkpoint, then restores ``params`` and, if requested, ``opt_state`` via
    ``read_sharded_v2``.

    Args:
        model_state: Template mapping with ``"params"`` and ``"opt_state"``.
        dir: Checkpoint directory, with or without a trailing slash.
        state_shard: Mapping with ``"params"`` and ``"opt_state"`` sharding
            descriptions, passed through to ``read_sharded_v2``.
        load_opt: When false, the result has no ``"opt_state"`` entry.

    Returns:
        Dict with ``"step"`` (a one-element array), ``"params"`` and,
        if ``load_opt``, ``"opt_state"``.
    """
    # Paths below are built by plain concatenation, so make sure a non-empty
    # directory ends with a separator. Inputs that already end with one (the
    # form this function previously required) are left unchanged.
    base = dir
    if base and not base.endswith(("/", "\\")):
        base += "/"

    start = time.time()
    with open(base + "meta.json", "r") as f:
        meta = json.load(f)

    ckpt_hosts = meta["total_hosts"]

    head_print(f"meta loaded in {time.time() - start:.06}s")

    new_state = {
        "step": np.array([meta["step"]]),
    }

    start = time.time()
    new_state["params"] = read_sharded_v2(model_state["params"],
                                          base + "params",
                                          ckpt_hosts,
                                          state_shard["params"])
    head_print(f"params loaded in {time.time() - start:.06}s")

    if not load_opt:
        return new_state

    start = time.time()
    new_state["opt_state"] = read_sharded_v2(model_state["opt_state"],
                                             base + "opt_state",
                                             ckpt_hosts,
                                             state_shard["opt_state"])
    head_print(f"opt_state loaded in {time.time() - start:.06}s")

    return new_state