Commit 4e63515

Add more tests
1 parent 07270c3

1 file changed (+44, -2 lines)


test/test_assume_pure.py (+44, -2)
@@ -1,9 +1,11 @@
 from absl.testing import absltest

 import torch
+import torch.nn as nn
 import torch_xla # Required for XLA device and sync
 from torch_xla.experimental.assume_pure import assume_pure
 import torch_xla.core.xla_model as xm # For xm.xla_device() and xm.mark_step() / sync()
+import torch_xla.core.xla_builder as xb
 from torch_xla._internal.jax_workarounds import jax_import_guard


@@ -40,14 +42,54 @@ def simple_torch_function(a, b):
       return torch.sin(a @ b)

     a = torch.ones((3, 3), device='xla', requires_grad=True)
-    b = torch.ones((3, 3), device='xla', requires_grad=True)
-    o = simple_torch_function(a, b)
+    o = simple_torch_function(a, a)
     o.sum().backward()

     torch_xla.sync()
     torch.testing.assert_close(
         o, torch.sin(torch.ones(3, 3) @ torch.ones(3, 3)), check_device=False)

+  def test_assume_pure_module(self):
+    model = nn.Linear(3, 3).to('xla')
+
+    @assume_pure
+    def simple_torch_function(params, x):
+      return torch.func.functional_call(model, params, x)
+
+    a = torch.ones((3, 3), device='xla', requires_grad=True)
+    o = simple_torch_function(dict(model.named_parameters()), a)
+    o.sum().backward()
+
+    torch_xla.sync()
+
+    torch.testing.assert_close(
+        o, model(torch.ones(3, 3).to('xla')), check_device=False)
+
+  def test_assume_pure_avoid_retracing_avoid_rejit(self):
+    starting_lowerings = xb._jax_to_hlo_cache_num_misses()
+    trace_counter = 0
+
+    @assume_pure
+    def simple_torch_function(a, b):
+      nonlocal trace_counter
+      trace_counter += 1
+      return torch.sin(a @ b)
+
+    # Simulate a training loop.
+    for _ in range(5):
+      a = torch.ones((3, 3), device='xla', requires_grad=True)
+      o = simple_torch_function(a, a)
+      o.sum().backward()
+      torch_xla.sync()
+
+    ending_lowerings = xb._jax_to_hlo_cache_num_misses()
+
+    # Check that we only trace once.
+    self.assertEqual(trace_counter, 1)
+
+    # Check that we only lower to HLO twice (once for forward, once for backward).
+    self.assertEqual(ending_lowerings - starting_lowerings, 2)
+
   def test_assume_pure_matmul_grads(self):
     """Tests matmul with all inputs requiring gradients."""
5395
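The new test_assume_pure_avoid_retracing_avoid_rejit asserts that a function wrapped with @assume_pure is traced exactly once and lowered to HLO only twice (once for forward, once for backward) across five iterations. Below is a minimal sketch of that same usage pattern outside the test harness, assuming an XLA device is available; the name train_step is illustrative and not part of this commit.

import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure


@assume_pure
def train_step(a, b):
  # Per the test above, this body is traced once; subsequent calls with the
  # same input shapes reuse the cached trace and its HLO lowering.
  return torch.sin(a @ b)


a = torch.ones((3, 3), device='xla', requires_grad=True)
for _ in range(5):
  out = train_step(a, a)
  out.sum().backward()
  torch_xla.sync()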
