File tree 2 files changed +9
-8
lines changed
2 files changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -93,7 +93,9 @@ function run_all_tests {
93
93
run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
94
94
run_eager_debug python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
95
95
run_async_rng python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
96
- run_test python3 "$CDIR/test_checkpoint.py"
96
+ # TODO: enable this test after tf update, currently optimization_barrier does not
97
+ # work on CPU.
98
+ # run_test python3 "$CDIR/test_checkpoint.py"
97
99
run_test python3 "$CDIR/test_mp_replication.py"
98
100
run_test python3 "$CDIR/test_mp_all_to_all.py"
99
101
run_test python3 "$CDIR/test_mp_collective_permute.py"
Original file line number Diff line number Diff line change 5
5
import torch_xla.utils.checkpoint as checkpoint
6
6
7
7
8
- def run(grad_checkpoint):
8
+ def run():
9
9
device = xm.xla_device()
10
10
model = torch.nn.ModuleList([
11
11
torch.nn.Sequential(
@@ -22,16 +22,15 @@ def run(grad_checkpoint):
22
22
optimizer.zero_grad()
23
23
x = dummy_data
24
24
for n_l, layer in enumerate(model):
25
- x = checkpoint.checkpoint(layer, x)
25
+ if n_l > 0:
26
+ x = checkpoint.checkpoint(layer, x)
27
+ else:
28
+ x = layer(x)
26
29
dummy_loss = x.sum()
27
30
dummy_loss.backward()
28
31
optimizer.step()
29
32
xm.mark_step()
30
- xm.wait_device_ops()
31
33
32
34
33
35
if __name__ == "__main__":
34
- parser = argparse.ArgumentParser()
35
- parser.add_argument("--grad_checkpoint", type=int, required=True)
36
- args = parser.parse_args()
37
- run(args.grad_checkpoint)
36
+ run()
You can’t perform that action at this time.
0 commit comments