Skip to content

Commit 61f8ab4

Browse files
committed
fix checkpoint path
1 parent 9d15698 commit 61f8ab4

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

mesh_transformer/TPU_cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def save(self, step, bucket, path, aux=None, init=False, overwrite=False, keep_n
203203
res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}/", shard_id))
204204
elif self.version == 2:
205205
for node in self.nodes:
206-
res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}/", 0))
206+
res.append(node.write_ckpt.remote(f"gs://{bucket}/{path}/step_{step}", 0))
207207

208208
ray.get(res)
209209
print(f"Wrote checkpoint in {time.time() - start:.06}s")

mesh_transformer/checkpoint.py

+40-7
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,48 @@ def parallel_read(old, fname, validate=True):
206206
return jax.tree_unflatten(treedef, fix_dtype(new_vals))
207207

208208

209+
def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
210+
id_to_name = {}
211+
if getattr(pytree, "items", None):
212+
for k, v in pytree.items():
213+
k_path = f"{path}/{k}"
214+
if is_leaf(v):
215+
id_to_name[to_id(v)] = k_path
216+
else:
217+
id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path)}
218+
elif getattr(pytree, "__getitem__", None):
219+
for v in pytree:
220+
if is_leaf(v):
221+
id_to_name[to_id(v)] = path
222+
else:
223+
id_to_name = {**id_to_name, **tree_flatten_with_names(v, is_leaf=is_leaf, path=path)}
224+
else:
225+
id_to_name[to_id(pytree)] = path
226+
return id_to_name
227+
228+
229+
def tree_leaves_with_names(pytree, to_id=id):
230+
leaves = jax.tree_leaves(pytree)
231+
is_leaf = lambda x: not isinstance(x, list) and to_id(x) in [to_id(x) for x in leaves]
232+
return tree_flatten_with_names(pytree, is_leaf)
233+
234+
209235
def write_ckpt_v2(model_state, dir):
210236
start = time.time()
211237
if jax.host_id() == 0:
238+
param_map = tree_leaves_with_names(model_state["params"])
239+
opt_map = tree_leaves_with_names(model_state["opt_state"])
240+
241+
meta = {
242+
"total_hosts": jax.host_count(),
243+
"step": int(model_state["step"]),
244+
"param_order": [param_map[id(i)] for i in jax.tree_leaves(model_state["params"])],
245+
"opt_order": [opt_map[id(i)] for i in jax.tree_leaves(model_state["opt_state"])]
246+
}
247+
212248
print("step:", model_state["step"])
213249
with open(dir + "/meta.json", "w") as f:
214-
json.dump({"total_hosts": jax.host_count(), "step": int(model_state["step"])}, f)
250+
json.dump(meta, f)
215251
print(f"meta written in {time.time() - start:.06}s")
216252

217253
start = time.time()
@@ -289,11 +325,8 @@ def reshard_v2(old, shard_strategy, *new_values):
289325

290326

291327
def load_ckpt_v2(model_state, dir, state_shard, load_opt):
292-
while dir.endswith("/"):
293-
dir = dir[:-1]
294-
295328
start = time.time()
296-
with open(dir + "/meta.json", "r") as f:
329+
with open(dir + "meta.json", "r") as f:
297330
meta = json.load(f)
298331

299332
ckpt_hosts = meta["total_hosts"]
@@ -306,7 +339,7 @@ def load_ckpt_v2(model_state, dir, state_shard, load_opt):
306339

307340
start = time.time()
308341
new_state["params"] = read_sharded_v2(model_state["params"],
309-
dir + "/params",
342+
dir + "params",
310343
ckpt_hosts,
311344
state_shard["params"])
312345
head_print(f"params loaded in {time.time() - start:.06}s")
@@ -316,7 +349,7 @@ def load_ckpt_v2(model_state, dir, state_shard, load_opt):
316349

317350
start = time.time()
318351
new_state["opt_state"] = read_sharded_v2(model_state["opt_state"],
319-
dir + "/opt_state",
352+
dir + "opt_state",
320353
ckpt_hosts,
321354
state_shard["opt_state"])
322355
head_print(f"opt_state loaded in {time.time() - start:.06}s")

mesh_transformer/train_actor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ def run(self):
2222
import haiku as hk
2323
# jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
2424

25-
# thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
25+
thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
2626

2727
start = time.time()
2828
jax.devices()
2929

3030
import warnings
3131
warnings.filterwarnings("ignore")
32+
warnings.filterwarnings("ignore", category=ResourceWarning)
3233

3334
if jax.host_id() == 0:
3435
warnings.filterwarnings("default")

0 commit comments

Comments
 (0)