
Commit f16c243

update test
1 parent da0356e commit f16c243

2 files changed: +9 -8 lines changed

test/run_tests.sh (+3 -1)

@@ -93,7 +93,9 @@ function run_all_tests {
   run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_eager_debug python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_async_rng python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
-  run_test python3 "$CDIR/test_checkpoint.py"
+  # TODO: enable this test after tf update, currently optimization_barrier does not
+  # work on CPU.
+  # run_test python3 "$CDIR/test_checkpoint.py"
   run_test python3 "$CDIR/test_mp_replication.py"
   run_test python3 "$CDIR/test_mp_all_to_all.py"
   run_test python3 "$CDIR/test_mp_collective_permute.py"
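
The TODO above disables the test at the runner level. A minimal sketch of an alternative, assuming xm.xla_device_hw (which torch_xla provides to report the hardware type behind an XLA device), would be to skip inside the test itself when the device is CPU-backed, where optimization_barrier is unsupported:

import torch_xla.core.xla_model as xm

# Hypothetical guard for test_checkpoint.py: skip on CPU rather than
# commenting the runner line out. `run()` is the test entry point shown
# in the diff below.
device = xm.xla_device()
if xm.xla_device_hw(device) == 'CPU':
  print('optimization_barrier does not work on CPU; skipping')
else:
  run()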

test/test_checkpoint.py (+6 -7)

@@ -5,7 +5,7 @@
 import torch_xla.utils.checkpoint as checkpoint


-def run(grad_checkpoint):
+def run():
   device = xm.xla_device()
   model = torch.nn.ModuleList([
       torch.nn.Sequential(
@@ -22,16 +22,15 @@ def run(grad_checkpoint):
     optimizer.zero_grad()
     x = dummy_data
     for n_l, layer in enumerate(model):
-      x = checkpoint.checkpoint(layer, x)
+      if n_l > 0:
+        x = checkpoint.checkpoint(layer, x)
+      else:
+        x = layer(x)
     dummy_loss = x.sum()
     dummy_loss.backward()
     optimizer.step()
     xm.mark_step()
-    xm.wait_device_ops()


 if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument("--grad_checkpoint", type=int, required=True)
-  args = parser.parse_args()
-  run(args.grad_checkpoint)
+  run()
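
A likely reason the first layer is no longer checkpointed: dummy_data does not require grad, and the reentrant-style checkpoint that torch_xla.utils.checkpoint builds on produces no gradients when none of its inputs require grad (stock torch.utils.checkpoint warns about exactly this). Running layer 0 normally hands the later layers an activation that does require grad. Below is a minimal, self-contained sketch of the same pattern using stock torch.utils.checkpoint, so it runs without an XLA device; the shapes and layer count are illustrative, not taken from the test:

import torch
from torch.utils import checkpoint

# Four small blocks standing in for the test's ModuleList of Sequentials.
model = torch.nn.ModuleList([
    torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    for _ in range(4)
])

x = torch.randn(8, 16)  # like dummy_data: requires_grad=False
for n_l, layer in enumerate(model):
  if n_l > 0:
    # After layer 0, x depends on parameters, so checkpointing is safe.
    x = checkpoint.checkpoint(layer, x)
  else:
    x = layer(x)
x.sum().backward()
print(model[0][0].weight.grad is not None)  # True: grads reach every layer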
