@@ -206,12 +206,48 @@ def parallel_read(old, fname, validate=True):
206
206
return jax .tree_unflatten (treedef , fix_dtype (new_vals ))
207
207
208
208
209
+ def tree_flatten_with_names (pytree , is_leaf , path = "" , to_id = id ):
210
+ id_to_name = {}
211
+ if getattr (pytree , "items" , None ):
212
+ for k , v in pytree .items ():
213
+ k_path = f"{ path } /{ k } "
214
+ if is_leaf (v ):
215
+ id_to_name [to_id (v )] = k_path
216
+ else :
217
+ id_to_name = {** id_to_name , ** tree_flatten_with_names (v , is_leaf = is_leaf , path = k_path )}
218
+ elif getattr (pytree , "__getitem__" , None ):
219
+ for v in pytree :
220
+ if is_leaf (v ):
221
+ id_to_name [to_id (v )] = path
222
+ else :
223
+ id_to_name = {** id_to_name , ** tree_flatten_with_names (v , is_leaf = is_leaf , path = path )}
224
+ else :
225
+ id_to_name [to_id (pytree )] = path
226
+ return id_to_name
227
+
228
+
229
+ def tree_leaves_with_names (pytree , to_id = id ):
230
+ leaves = jax .tree_leaves (pytree )
231
+ is_leaf = lambda x : not isinstance (x , list ) and to_id (x ) in [to_id (x ) for x in leaves ]
232
+ return tree_flatten_with_names (pytree , is_leaf )
233
+
234
+
209
235
def write_ckpt_v2 (model_state , dir ):
210
236
start = time .time ()
211
237
if jax .host_id () == 0 :
238
+ param_map = tree_leaves_with_names (model_state ["params" ])
239
+ opt_map = tree_leaves_with_names (model_state ["opt_state" ])
240
+
241
+ meta = {
242
+ "total_hosts" : jax .host_count (),
243
+ "step" : int (model_state ["step" ]),
244
+ "param_order" : [param_map [id (i )] for i in jax .tree_leaves (model_state ["params" ])],
245
+ "opt_order" : [opt_map [id (i )] for i in jax .tree_leaves (model_state ["opt_state" ])]
246
+ }
247
+
212
248
print ("step:" , model_state ["step" ])
213
249
with open (dir + "/meta.json" , "w" ) as f :
214
- json .dump ({ "total_hosts" : jax . host_count (), "step" : int ( model_state [ "step" ])} , f )
250
+ json .dump (meta , f )
215
251
print (f"meta written in { time .time () - start :.06} s" )
216
252
217
253
start = time .time ()
@@ -289,11 +325,8 @@ def reshard_v2(old, shard_strategy, *new_values):
289
325
290
326
291
327
def load_ckpt_v2 (model_state , dir , state_shard , load_opt ):
292
- while dir .endswith ("/" ):
293
- dir = dir [:- 1 ]
294
-
295
328
start = time .time ()
296
- with open (dir + "/ meta.json" , "r" ) as f :
329
+ with open (dir + "meta.json" , "r" ) as f :
297
330
meta = json .load (f )
298
331
299
332
ckpt_hosts = meta ["total_hosts" ]
@@ -306,7 +339,7 @@ def load_ckpt_v2(model_state, dir, state_shard, load_opt):
306
339
307
340
start = time .time ()
308
341
new_state ["params" ] = read_sharded_v2 (model_state ["params" ],
309
- dir + "/ params" ,
342
+ dir + "params" ,
310
343
ckpt_hosts ,
311
344
state_shard ["params" ])
312
345
head_print (f"params loaded in { time .time () - start :.06} s" )
@@ -316,7 +349,7 @@ def load_ckpt_v2(model_state, dir, state_shard, load_opt):
316
349
317
350
start = time .time ()
318
351
new_state ["opt_state" ] = read_sharded_v2 (model_state ["opt_state" ],
319
- dir + "/ opt_state" ,
352
+ dir + "opt_state" ,
320
353
ckpt_hosts ,
321
354
state_shard ["opt_state" ])
322
355
head_print (f"opt_state loaded in { time .time () - start :.06} s" )
0 commit comments