Skip to content

Commit 4637fb0

Browse files
committed
Simplify
1 parent f625d70 commit 4637fb0

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

torch_xla/experimental/scan.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ def maybe_get_first(v):
399399
return grad_carry, grad_x
400400

401401

402+
def _save_for_backward(ctx, pytree) -> None:
403+
flat, tree_spec = tree_flatten(pytree)
404+
ctx._saved_tensors_spec = tree_spec
405+
ctx.save_for_backward(*flat)
406+
407+
408+
def _load_from_context(ctx):
409+
return tree_unflatten(ctx.saved_tensors, ctx._saved_tensors_spec)
410+
411+
402412
@pytreeify
403413
class Scan(torch.autograd.Function):
404414

@@ -409,37 +419,28 @@ def forward(ctx, forward, alias_input, backward, init, xs):
409419
ys, partial_activations = ys
410420
activations = alias_input(partial_activations, xs)
411421
ctx._backward = backward
412-
413-
if torch_xla.runtime.is_spmd():
414-
flat_init, carry_spec = tree_flatten(init)
415-
flat_xs, xs_spec = tree_flatten(xs)
416-
ctx._carry_spec = carry_spec
417-
ctx._xs_spec = xs_spec
418-
ctx._flat_init_len = len(flat_init)
419-
ctx._flat_xs_len = len(flat_xs)
420-
ctx.save_for_backward(*flat_init, *flat_xs, *activations)
421-
else:
422-
ctx.save_for_backward(*activations)
423-
422+
_save_for_backward(ctx, {
423+
"init": init,
424+
"xs": xs,
425+
"activations": activations
426+
})
424427
return carry, ys
425428

426429
@staticmethod
427430
def backward(ctx, grad_carry, grad_ys): # type: ignore
431+
saved = _load_from_context(ctx)
428432
if torch_xla.runtime.is_spmd():
429-
stuff = ctx.saved_tensors
430-
flat_init, flat_xs, activations = split(stuff, ctx._flat_init_len,
431-
ctx._flat_xs_len)
432-
init = tree_unflatten(flat_init, ctx._carry_spec)
433-
xs = tree_unflatten(flat_xs, ctx._xs_spec)
434433
backward = partial(
435-
_backward_shard_alike, backward=ctx._backward, init=init, xs=xs)
434+
_backward_shard_alike,
435+
backward=ctx._backward,
436+
init=saved["init"],
437+
xs=saved["xs"])
436438
else:
437-
activations = ctx.saved_tensors
438439
backward = ctx._backward
439440
with torch.no_grad():
440441
# Reverse loop to propagate gradients from last iteration to first.
441442
grad_init, grad_xs = _scan_impl_pytree(
442-
backward, grad_carry, (grad_ys, activations), reverse=True)
443+
backward, grad_carry, (grad_ys, saved["activations"]), reverse=True)
443444
return None, None, None, grad_init, grad_xs
444445

445446

0 commit comments

Comments
 (0)