
Commit de15e74

anijain2305 authored and pytorchmergebot committed
[dynamo] Activation checkpointing as higher order op (pytorch#101028)
Pull Request resolved: pytorch#101028 Approved by: https://github.com/voznesenskym, https://github.com/zou3519
1 parent c5c75aa commit de15e74

File tree: 8 files changed, +269 -9 lines changed

test/dynamo/test_higher_order_ops.py (+167)

@@ -1,12 +1,20 @@
 # Owner(s): ["module: dynamo"]
+import functools
 import unittest

 import torch

 import torch._dynamo.test_case
+import torch._functorch.config
+import torch.utils.checkpoint
+from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
 from torch._dynamo.utils import counters
 from torch._higher_order_ops.wrap import wrap
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+
+requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


 # Equivalent to backend="eager", but also records graphs that
@@ -20,6 +28,11 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
         return gm


+def count_ops(gm, args, freq, op):
+    assert [node.target for node in gm.graph.nodes].count(op) == freq
+    return gm
+
+
 global_var = torch.randn(3)
 global_num = 3.14

@@ -406,6 +419,160 @@ def f(x):
         self._test_wrap_simple(f, (x,), 3, expected_opcount=2)


+class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
+    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
+        cloned_args = []
+        for arg in args:
+            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
+
+        expected = fn(*args)
+        expected.sum().backward()
+
+        result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
+        result.sum().backward()
+
+        if not skip_check:
+            self.assertEqual(result, expected)
+            for arg, cloned_arg in zip(args, cloned_args):
+                self.assertEqual(arg.grad, cloned_arg.grad)
+
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_function(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x, y)
+
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_function_with_kwargs(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
+            )
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x, y)
+
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_dropout(self):
+        def gn(x, y):
+            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+
+        x = torch.randn(4, 4, device="cuda", requires_grad=True)
+        y = torch.randn(4, 4, device="cuda", requires_grad=True)
+
+        fw_compiler = functools.partial(
+            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
+        )
+        bw_compiler = functools.partial(
+            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
+        )
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(
+            fn, backend, x, y, skip_check=True
+        )  # dropout decomp is known to diverge with eager
+
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_fallback(self):
+        def gn(x, y):
+            torch._dynamo.graph_break()
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        expected = fn(*args)
+        result = torch.compile(fn, backend=cnt)(*args)
+
+        self.assertEqual(result, expected)
+
+        # One graph for torch.sin on the input, and other for torch.cos.
+        self.assertEqual(cnt.frame_count, 2)
+        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(len(backend.graphs), 2)
+
+    def test_without_functionalization_turned_on(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        expected = fn(*args)
+        result = torch.compile(fn, backend=cnt)(*args)
+
+        self.assertEqual(result, expected)
+
+    # Higher order op does not support nn.Modules yet
+    @unittest.expectedFailure
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_module(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x):
+                return torch.sigmoid(self.linear(x))
+
+        mod = MockModule()
+
+        def fn(x):
+            return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
+
+        x = torch.randn(10, 10, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x)
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

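Distilled from the tests above, a minimal end-to-end sketch of the intended usage (a sketch rather than code from this commit; it assumes a CUDA-capable build, since RNG functionalization currently targets CUDA, and it uses the stock "aot_eager" backend in place of the op-counting aot_autograd backend the tests construct):

import torch
import torch._functorch.config
import torch.utils.checkpoint


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    # Same shape as the tests: checkpoint the inner function gn.
    return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)


x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)


# The flag has to be flipped manually for now (see the warning added in
# torch/_dynamo/utils.py below); without it, the checkpointed region falls
# back to eager instead of being captured as a higher-order op.
@torch._functorch.config.patch(functionalize_rng_ops=True)
def run():
    out = torch.compile(fn, backend="aot_eager")(x, y)
    out.sum().backward()


run()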
torch/_dynamo/eval_frame.py (+13)

@@ -23,6 +23,7 @@
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
+import torch.utils.checkpoint
 from torch import _guards
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
@@ -1270,6 +1271,18 @@ def patch():
             # disable future hooking
             opt.step.hooked = True

+        # TorchDynamo does not step inside utils.checkpoint function. The flow
+        # looks likes this
+        #  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
+        #     speculatively checking if the forward function is safe to trace.
+        #  2) If yes, then Dynamo-generated Fx graph has the wrapped higher
+        #     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
+        #  3) If not, then TorchDynamo falls back to eager by performing a graph
+        #     break. And here, the following disable wrapper ensures that
+        #     TorchDynamo does not trigger again on the frames created by
+        #     utils.checkpoint innards.
+        torch.utils.checkpoint.checkpoint = disable(torch.utils.checkpoint.checkpoint)
+
     @staticmethod
     def suppress_torch_distributed_warnings(fn):
         def inner_fn(*args, **kwargs):

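To make step 3 above concrete, here is a sketch mirroring the new test_fallback case (not code from this commit): the explicit graph break inside the checkpointed function makes the speculative trace fail, so Dynamo compiles the surrounding torch.sin/torch.cos while the disable() wrapper installed here keeps it from re-entering the frames created inside utils.checkpoint.

import torch
import torch.utils.checkpoint


def gn(x, y):
    torch._dynamo.graph_break()  # body is unsafe to trace, forcing the fallback path
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))


x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)

# torch.sin and torch.cos land in separate compiled graphs; the checkpointed
# body itself runs in eager mode.
out = torch.compile(fn, backend="eager")(x, y)
out.sum().backward()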
torch/_dynamo/utils.py (+47)

@@ -48,13 +48,17 @@
 import importlib

 import torch
+import torch._functorch.config
+import torch._higher_order_ops.wrap
 import torch.fx.experimental.symbolic_shapes
+import torch.utils.checkpoint
 from torch import fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.utils._pytree import tree_map

+
 counters = collections.defaultdict(collections.Counter)
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
 nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
@@ -1620,3 +1624,46 @@ def defake(x):
     )
     y.zero_()
     return y
+
+
+# NB: The dictionary has to be created lazily after TorchPatcher is called so
+# that we pick up the disabled torch.utils.checkpoint wrapper. Therefore, it is
+# sitting in a separate function.
+@functools.lru_cache(None)
+def higher_order_op_converter():
+    return {
+        torch.utils.checkpoint.checkpoint: torch._higher_order_ops.wrap.wrap_activation_checkpoint,
+    }
+
+
+def requires_higher_order_op(obj):
+    return obj in higher_order_op_converter()
+
+
+def get_higher_order_op(obj):
+    if (
+        obj is torch.utils.checkpoint.checkpoint
+        and not torch._functorch.config.functionalize_rng_ops
+    ):
+        from .exc import unimplemented
+
+        # TODO - functionalize_rng_ops flags cannot be turned ON by default
+        # because 1) Performance concerns - seed and offset are read and passed
+        # to each AOT graph 2) Inductor has rand-specific optimizations and
+        # there is work remaining to compose them together with
+        # functionalization.
+        #
+        # Until we make it ON by default, we will have to ask users to turn on
+        # this flag manually. TODO - Revisit if there is a simpler way to
+        # resolve this problem.
+        torch._logging.warning_once(
+            log,
+            "torch.compile on activation checkpointing is an experimental feature. "
+            "Please manually set torch._functorch.config.functionalize_rng_ops=True "
+            "to run torch.compile with activation checkpointing. Without this flag, "
+            "checkpointed function will not get compiled and fallback to eager.",
+        )
+        unimplemented(
+            "torch.compile requires functioanlization of rng ops to be turned on"
+        )
+    return higher_order_op_converter().get(obj)

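A rough illustration of how the variable builders below consult these helpers (a sketch, not code from this commit; note that the lru_cache is intentionally lazy so the converter keys on the disable()-wrapped checkpoint installed by TorchPatcher, and calling the helpers before Dynamo has patched anything can therefore cache a stale key):

import torch
import torch._functorch.config
import torch.utils.checkpoint
from torch._dynamo.utils import get_higher_order_op, requires_higher_order_op

# Ordinary torch callables are left alone ...
assert not requires_higher_order_op(torch.sin)


# ... but torch.utils.checkpoint.checkpoint is listed in the converter table,
# so Dynamo rewrites it instead of skipping or inlining it.
@torch._functorch.config.patch(functionalize_rng_ops=True)
def resolve():
    # With the flag on this returns the wrapping higher-order op; without it,
    # get_higher_order_op raises Unsupported, which Dynamo turns into a
    # graph break.
    return get_higher_order_op(torch.utils.checkpoint.checkpoint)


print(resolve())  # the wrap_activation_checkpoint higher-order op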
torch/_dynamo/variables/__init__.py (+1 -1)

@@ -46,7 +46,7 @@
     TensorVariable,
     UnspecializedPythonVariable,
 )
-from .torch import TorchVariable
+from .torch import TorchHigherOrderOperatorVariable, TorchVariable
 from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable

 __all__ = [

torch/_dynamo/variables/builder.py (+4 -2)

@@ -53,6 +53,7 @@
     np,
     odict_values,
     preserve_rng_state,
+    requires_higher_order_op,
     tensor_always_has_static_shape,
     torch_np,
     tuple_iterator,
@@ -103,7 +104,7 @@
 from .torch import (
     tensor_dunder_fns,
     torch_special_class_types,
-    TorchHigherOrderOperator,
+    TorchHigherOrderOperatorVariable,
     TorchVariable,
 )
 from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
@@ -425,6 +426,7 @@ def index_source(key):
             istype(value, (type, types.FunctionType))
             and skipfiles.check(getfile(value), allow_torch=True)
             and not inspect.getattr_static(value, "_torchdynamo_inline", False)
+            and not requires_higher_order_op(value)
         ):
             return SkipFilesVariable(
                 value,
@@ -489,7 +491,7 @@ def index_source(key):
                 value, guards=make_guards(GuardBuilder.TYPE_MATCH)
             )
         elif isinstance(value, HigherOrderOperator):
-            return TorchHigherOrderOperator(
+            return TorchHigherOrderOperatorVariable(
                 value,
                 guards=self.make_guards(
                     GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH

torch/_dynamo/variables/builtin.py (+8 -1)

@@ -19,8 +19,10 @@
 from ..utils import (
     check_constant_args,
     check_unspec_python_args,
+    get_higher_order_op,
     istype,
     proxy_args_kwargs,
+    requires_higher_order_op,
     specialize_args_kwargs,
 )
 from .base import MutableLocal, typestr, VariableTracker
@@ -992,6 +994,7 @@ def call_getattr(
             ConstantVariable,
             GetAttrVariable,
             PythonModuleVariable,
+            TorchHigherOrderOperatorVariable,
             TorchVariable,
             UserFunctionVariable,
         )
@@ -1059,7 +1062,11 @@ def call_getattr(
             return GetAttrVariable(obj, name, **options)
         elif isinstance(obj, TorchVariable):
             member = getattr(obj.value, name)
-            if is_allowed(member):
+            if requires_higher_order_op(member):
+                return TorchHigherOrderOperatorVariable(
+                    get_higher_order_op(member), **options
+                )
+            elif is_allowed(member):
                 return TorchVariable(member, **options)
             elif ConstantVariable.is_literal(member):
                 return ConstantVariable(member, **options)
