# Owner(s): ["module: dynamo"]
+import functools
import unittest

import torch

import torch._dynamo.test_case
+import torch._functorch.config
+import torch.utils.checkpoint
+from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
from torch._dynamo.utils import counters
from torch._higher_order_ops.wrap import wrap
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+
+requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


# Equivalent to backend="eager", but also records graphs that
@@ -20,6 +28,11 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        return gm


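+# Minimal graph "compiler" passed to aot_autograd below: asserts that `op`
+# appears exactly `freq` times in the traced graph and returns it unchanged.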
+def count_ops(gm, args, freq, op):
+    assert [node.target for node in gm.graph.nodes].count(op) == freq
+    return gm
+
+
global_var = torch.randn(3)
global_num = 3.14

@@ -406,6 +419,160 @@ def f(x):
        self._test_wrap_simple(f, (x,), 3, expected_opcount=2)


+class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
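+    # Run `fn` eagerly and under torch.compile with the given backend, then
+    # compare outputs and input gradients (unless skip_check is set).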
+    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
+        cloned_args = []
+        for arg in args:
+            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
+
+        expected = fn(*args)
+        expected.sum().backward()
+
+        result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
+        result.sum().backward()
+
+        if not skip_check:
+            self.assertEqual(result, expected)
+            for arg, cloned_arg in zip(args, cloned_args):
+                self.assertEqual(arg.grad, cloned_arg.grad)
+
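+    # Checkpointed matmul + sigmoid: expect 1 mm in the forward graph and 3 in
+    # the backward graph (the recomputed forward mm plus, presumably, the mms
+    # that produce the input gradients).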
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_function(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x, y)
+
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_function_with_kwargs(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
+            )
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x, y)
+
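+    # With functionalize_rng_ops, dropout's RNG is traced as philox_rand; expect
+    # one call in the forward graph and one in the backward graph, so the
+    # recompute presumably sees the same randomness as the forward.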
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_dropout(self):
+        def gn(x, y):
+            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+
+        x = torch.randn(4, 4, device="cuda", requires_grad=True)
+        y = torch.randn(4, 4, device="cuda", requires_grad=True)
+
+        fw_compiler = functools.partial(
+            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
+        )
+        bw_compiler = functools.partial(
+            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
+        )
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(
+            fn, backend, x, y, skip_check=True
+        )  # dropout decomp is known to diverge with eager
+
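+    # The graph_break() inside the checkpointed function makes dynamo fall back
+    # to eager checkpointing, so compilation splits into separate graphs.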
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_fallback(self):
+        def gn(x, y):
+            torch._dynamo.graph_break()
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        expected = fn(*args)
+        result = torch.compile(fn, backend=cnt)(*args)
+
+        self.assertEqual(result, expected)
+
+        # One graph for torch.sin on the input, and the other for torch.cos.
+        self.assertEqual(cnt.frame_count, 2)
+        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(len(backend.graphs), 2)
+
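+    # Without functionalize_rng_ops (and without fullgraph), only eager-vs-compiled
+    # correctness is checked here; no graph or op counts are asserted.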
+    def test_without_functionalization_turned_on(self):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        expected = fn(*args)
+        result = torch.compile(fn, backend=cnt)(*args)
+
+        self.assertEqual(result, expected)
+
+    # Higher order op does not support nn.Modules yet
+    @unittest.expectedFailure
+    @requires_cuda()
+    @torch._functorch.config.patch(functionalize_rng_ops=True)
+    def test_module(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x):
+                return torch.sigmoid(self.linear(x))
+
+        mod = MockModule()
+
+        def fn(x):
+            return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
+
+        x = torch.randn(10, 10, requires_grad=True)
+
+        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
+        bw_compiler = functools.partial(
+            count_ops, freq=3, op=torch.ops.aten.mm.default
+        )  # mm recomputed in the bwd
+        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
+        self._validate(fn, backend, x)
+
+
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests