
Commit da0356e

Add checkpoint code
1 parent 1a0ae30 commit da0356e

2 files changed: 38 additions & 0 deletions

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ function run_all_tests {
   run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_eager_debug python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_async_rng python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
+  run_test python3 "$CDIR/test_checkpoint.py"
   run_test python3 "$CDIR/test_mp_replication.py"
   run_test python3 "$CDIR/test_mp_all_to_all.py"
   run_test python3 "$CDIR/test_mp_collective_permute.py"
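
The new test is invoked through the same run_test helper as the other scripts in run_all_tests. Note that test_checkpoint.py (added below) declares --grad_checkpoint as a required argument, so invoking the script directly needs the flag, e.g. python3 test/test_checkpoint.py --grad_checkpoint=1.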

test/test_checkpoint.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import argparse
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import torch_xla
+import torch_xla.utils.checkpoint as checkpoint
+
+
+def run(grad_checkpoint):
+  device = xm.xla_device()
+  model = torch.nn.ModuleList([
+      torch.nn.Sequential(
+          torch.nn.Conv2d(1024, 1024, 1),
+          torch.nn.ReLU(),
+          torch.nn.Conv2d(1024, 1024, 1),
+          torch.nn.ReLU(),
+      ) for _ in range(2)
+  ]).to(device)
+  optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
+
+  for step in range(20):
+    dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
+    optimizer.zero_grad()
+    x = dummy_data
+    for n_l, layer in enumerate(model):
+      x = checkpoint.checkpoint(layer, x)
+    dummy_loss = x.sum()
+    dummy_loss.backward()
+    optimizer.step()
+    xm.mark_step()
+    xm.wait_device_ops()
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--grad_checkpoint", type=int, required=True)
+  args = parser.parse_args()
+  run(args.grad_checkpoint)
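
As committed, run() parses a --grad_checkpoint flag but wraps every layer in checkpoint.checkpoint regardless of the flag's value. The helper below is a minimal sketch, not part of this commit, of how the flag could gate gradient checkpointing; only torch_xla.utils.checkpoint.checkpoint is taken from the diff above, the forward_layers helper itself is hypothetical.

import torch_xla.utils.checkpoint as checkpoint


def forward_layers(model, x, grad_checkpoint):
  # Hypothetical helper: with checkpointing enabled, each layer's intermediate
  # activations are recomputed during the backward pass instead of being
  # stored, trading extra compute for lower memory use.
  for layer in model:
    if grad_checkpoint:
      x = checkpoint.checkpoint(layer, x)
    else:
      x = layer(x)
  return x

With lr=0.0 and no assertions on the output, the test's loop mainly exercises graph construction and execution with checkpointed layers on the XLA device, via xm.mark_step() followed by xm.wait_device_ops().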
