@@ -399,6 +399,16 @@ def maybe_get_first(v):
    return grad_carry, grad_x


+def _save_for_backward(ctx, pytree) -> None:
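+  # Flatten the pytree into tensor leaves and a treespec: the tensors go
+  # through the regular autograd saving mechanism, while the spec is stashed
+  # on ctx so backward can rebuild the original structure.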
+  flat, tree_spec = tree_flatten(pytree)
+  ctx._saved_tensors_spec = tree_spec
+  ctx.save_for_backward(*flat)
+
+
+def _load_from_context(ctx):
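+  # Inverse of _save_for_backward: reassemble the saved tensor leaves into
+  # the original pytree using the stashed treespec.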
+  return tree_unflatten(ctx.saved_tensors, ctx._saved_tensors_spec)
+
+
@pytreeify
class Scan(torch.autograd.Function):

@@ -409,37 +419,28 @@ def forward(ctx, forward, alias_input, backward, init, xs):
    ys, partial_activations = ys
    activations = alias_input(partial_activations, xs)
    ctx._backward = backward
-
-    if torch_xla.runtime.is_spmd():
-      flat_init, carry_spec = tree_flatten(init)
-      flat_xs, xs_spec = tree_flatten(xs)
-      ctx._carry_spec = carry_spec
-      ctx._xs_spec = xs_spec
-      ctx._flat_init_len = len(flat_init)
-      ctx._flat_xs_len = len(flat_xs)
-      ctx.save_for_backward(*flat_init, *flat_xs, *activations)
-    else:
-      ctx.save_for_backward(*activations)
-
+    _save_for_backward(ctx, {
+        "init": init,
+        "xs": xs,
+        "activations": activations
+    })
    return carry, ys

  @staticmethod
  def backward(ctx, grad_carry, grad_ys):  # type: ignore
+    saved = _load_from_context(ctx)
    if torch_xla.runtime.is_spmd():
-      stuff = ctx.saved_tensors
-      flat_init, flat_xs, activations = split(stuff, ctx._flat_init_len,
-                                              ctx._flat_xs_len)
-      init = tree_unflatten(flat_init, ctx._carry_spec)
-      xs = tree_unflatten(flat_xs, ctx._xs_spec)
      backward = partial(
-          _backward_shard_alike, backward=ctx._backward, init=init, xs=xs)
+          _backward_shard_alike,
+          backward=ctx._backward,
+          init=saved["init"],
+          xs=saved["xs"])
    else:
-      activations = ctx.saved_tensors
      backward = ctx._backward
    with torch.no_grad():
      # Reverse loop to propagate gradients from last iteration to first.
      grad_init, grad_xs = _scan_impl_pytree(
-          backward, grad_carry, (grad_ys, activations), reverse=True)
+          backward, grad_carry, (grad_ys, saved["activations"]), reverse=True)
    return None, None, None, grad_init, grad_xs
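For reference, here is a minimal, self-contained sketch of the save/restore pattern the new helpers implement, applied to a toy `torch.autograd.Function` (the `Mul` class and the top-level script are illustrative assumptions, not part of this change; `tree_flatten`/`tree_unflatten` come from `torch.utils._pytree`). Note that every leaf of the saved pytree must be a tensor, since the leaves are passed to `ctx.save_for_backward`:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def _save_for_backward(ctx, pytree) -> None:
  # Tensor leaves go through save_for_backward; the treespec rides on ctx.
  flat, tree_spec = tree_flatten(pytree)
  ctx._saved_tensors_spec = tree_spec
  ctx.save_for_backward(*flat)


def _load_from_context(ctx):
  # Rebuild the original pytree from the saved tensor leaves.
  return tree_unflatten(ctx.saved_tensors, ctx._saved_tensors_spec)


class Mul(torch.autograd.Function):
  """Hypothetical example: save a dict of tensors instead of positional args."""

  @staticmethod
  def forward(ctx, a, b):
    _save_for_backward(ctx, {"a": a, "b": b})
    return a * b

  @staticmethod
  def backward(ctx, grad_out):
    saved = _load_from_context(ctx)
    return grad_out * saved["b"], grad_out * saved["a"]


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
Mul.apply(a, b).sum().backward()  # populates a.grad and b.grad
```

Routing only the tensor leaves through `save_for_backward` keeps autograd's in-place modification checks intact, while the treespec, being plain Python data, can safely live as an ordinary attribute on `ctx`.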