@@ -1,9 +1,11 @@
 from absl.testing import absltest

 import torch
+import torch.nn as nn
 import torch_xla  # Required for XLA device and sync
 from torch_xla.experimental.assume_pure import assume_pure
 import torch_xla.core.xla_model as xm  # For xm.xla_device() and xm.mark_step() / sync()
+import torch_xla.core.xla_builder as xb
 from torch_xla._internal.jax_workarounds import jax_import_guard


@@ -40,14 +42,54 @@ def simple_torch_function(a, b):
       return torch.sin(a @ b)

     a = torch.ones((3, 3), device='xla', requires_grad=True)
-    b = torch.ones((3, 3), device='xla', requires_grad=True)
-    o = simple_torch_function(a, b)
+    o = simple_torch_function(a, a)
     o.sum().backward()

     torch_xla.sync()
     torch.testing.assert_close(
         o, torch.sin(torch.ones(3, 3) @ torch.ones(3, 3)), check_device=False)

+  def test_assume_pure_module(self):
+    model = nn.Linear(3, 3).to('xla')
+
+    @assume_pure
+    def simple_torch_function(params, x):
+      return torch.func.functional_call(model, params, x)
+
+    a = torch.ones((3, 3), device='xla', requires_grad=True)
+    o = simple_torch_function(dict(model.named_parameters()), a)
+    o.sum().backward()
+
+    torch_xla.sync()
+
+    torch.testing.assert_close(
+        o, model(torch.ones(3, 3).to('xla')), check_device=False)
+
+  def test_assume_pure_avoid_retracing_avoid_rejit(self):
+    starting_lowerings = xb._jax_to_hlo_cache_num_misses()
+    trace_counter = 0
+
+    @assume_pure
+    def simple_torch_function(a, b):
+      nonlocal trace_counter
+      trace_counter += 1
+      return torch.sin(a @ b)
+
+    # Simulate a training loop.
+    for _ in range(5):
+      a = torch.ones((3, 3), device='xla', requires_grad=True)
+      o = simple_torch_function(a, a)
+      o.sum().backward()
+      torch_xla.sync()
+
+    ending_lowerings = xb._jax_to_hlo_cache_num_misses()
+
+    # Check that we only trace once.
+    self.assertEqual(trace_counter, 1)
+
+    # Check that we only lower to HLO twice (once for forward, once for backward).
+    self.assertEqual(ending_lowerings - starting_lowerings, 2)
+
   def test_assume_pure_matmul_grads(self):
     """Tests matmul with all inputs requiring gradients."""
|