|
| 1 | +import unittest |
| 2 | +import torch |
| 3 | +import torch_xla |
| 4 | +import torch_xla.core.xla_model as xm |
| 5 | +import torch_xla.test.test_utils as test_utils |
| 6 | +from torch_xla.experimental.gradient_accumulation import gradient_accumulation |
| 7 | + |
| 8 | +from test_utils import XlaTestCase # type:ignore |
| 9 | + |
| 10 | +class SimpleModel(torch.nn.Module): |
| 11 | + def __init__(self, input_dim=10, hidden_dim=20, output_dim=5): |
| 12 | + super(SimpleModel, self).__init__() |
| 13 | + self.fc1 = torch.nn.Linear(input_dim, hidden_dim) |
| 14 | + self.fc2 = torch.nn.Linear(hidden_dim, output_dim) |
| 15 | + |
| 16 | + def forward(self, x): |
| 17 | + x = torch.relu(self.fc1(x)) |
| 18 | + return self.fc2(x) |
| 19 | + |
| 20 | +class GradAccumulationTest(XlaTestCase): |
| 21 | + def setUp(self): |
| 22 | + self.device = xm.xla_device() |
| 23 | + torch.manual_seed(123) |
| 24 | + |
| 25 | + def test_basic(self): |
| 26 | + """Compare results with and without the XLA loop""" |
| 27 | + batch_size = 8 |
| 28 | + hidden_dim = 20 |
| 29 | + input_dim = 10 |
| 30 | + output_dim = 5 |
| 31 | + |
| 32 | + inputs = torch.randn(batch_size, input_dim).to(self.device) |
| 33 | + targets = torch.randn(batch_size, output_dim).to(self.device) |
| 34 | + |
| 35 | + def train_step_fw(input_batch, target_batch, carried_tensor): |
| 36 | + output = model_ga(input_batch) |
| 37 | + loss = torch.nn.functional.mse_loss(output, target_batch) |
| 38 | + new_carried_tensor = carried_tensor + 5 |
| 39 | + return loss, new_carried_tensor |
| 40 | + |
| 41 | + # Gradient accumulation with XLA loop |
| 42 | + torch.manual_seed(43) |
| 43 | + model_ga = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) |
| 44 | + carried_tensor_ga = torch.tensor([5,5]).to(self.device) |
| 45 | + |
| 46 | + |
| 47 | + accumulated_loss_ga, accum_carried_tensor_ga = gradient_accumulation( |
| 48 | + train_step_fw, |
| 49 | + (inputs, targets), |
| 50 | + model_ga, |
| 51 | + carried_tensor_ga |
| 52 | + ) |
| 53 | + |
| 54 | + torch_xla.sync() |
| 55 | + |
| 56 | + # Traditional accumulation |
| 57 | + torch.manual_seed(43) |
| 58 | + model_manual = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) |
| 59 | + carried_tensor_manual = torch.tensor([5,5]).to(self.device) |
| 60 | + |
| 61 | + accumulated_loss_manual = torch.tensor(0.0).to(self.device) |
| 62 | + for i in range(batch_size): |
| 63 | + loss, carried_tensor_manual = train_step_fw(inputs[i:i+1], targets[i:i+1], carried_tensor_manual) |
| 64 | + loss = loss / batch_size |
| 65 | + loss.backward() |
| 66 | + accumulated_loss_manual += loss.detach() |
| 67 | + |
| 68 | + torch_xla.sync() |
| 69 | + |
| 70 | + # Compare losses, carried tensors and resulting gradients |
| 71 | + super().compareResults([accumulated_loss_ga], [accumulated_loss_manual]) |
| 72 | + super().compareResults([accum_carried_tensor_ga], [carried_tensor_manual]) |
| 73 | + super().compareResults(model_ga.parameters(), model_manual.parameters()) |
| 74 | + |
| 75 | + def test_with_carried_tensors(self): |
| 76 | + """Test gradient accumulation with carried tensors, including with RNG""" |
| 77 | + batch_size = 2 |
| 78 | + hidden_dim = 20 |
| 79 | + input_dim = 10 |
| 80 | + output_dim = 5 |
| 81 | + |
| 82 | + model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) |
| 83 | + |
| 84 | + inputs = torch.randn(batch_size, input_dim).to(self.device) |
| 85 | + targets = torch.randn(batch_size, output_dim).to(self.device) |
| 86 | + |
| 87 | + # Carried tensors |
| 88 | + counter = torch.tensor(0).to(self.device) |
| 89 | + tensor0 = torch.tensor(0.0).to(self.device) |
| 90 | + tensor0_baseline = tensor0.clone() |
| 91 | + |
| 92 | + # Define train step function that updates the carried tensor. In the case of |
| 93 | + # RNG, we negate the previous value, in order to validate that we get unique |
| 94 | + # RNG seeds for each iteration. |
| 95 | + def train_step(input_batch, target_batch, counter, tensor0): |
| 96 | + output = model(input_batch) |
| 97 | + loss = torch.nn.functional.mse_loss(output, target_batch) |
| 98 | + # Update counter |
| 99 | + new_counter = counter + 1 |
| 100 | + new_tensor0 = torch.rand_like(tensor0, device=self.device) - tensor0 |
| 101 | + return loss, new_counter, new_tensor0 |
| 102 | + |
| 103 | + # Run gradient accumulation |
| 104 | + accumulated_loss, final_counter, final_tensor0 = gradient_accumulation( |
| 105 | + train_step, |
| 106 | + (inputs, targets), |
| 107 | + model, |
| 108 | + counter, tensor0 |
| 109 | + ) |
| 110 | + |
| 111 | + torch_xla.sync() |
| 112 | + |
| 113 | + self.assertEqual(final_counter.item(), batch_size) |
| 114 | + # Ensure that the result is not 0, showcasing that the RNG is unique |
| 115 | + # per iteration. |
| 116 | + self.assertNotEqual(final_tensor0.item(), 0.0) |
| 117 | + |
| 118 | + def test_error_empty_iterable_tensors(self): |
| 119 | + """Test that empty iterable_tensors raises an error.""" |
| 120 | + model = SimpleModel().to(self.device) |
| 121 | + |
| 122 | + def train_step(): |
| 123 | + pass |
| 124 | + |
| 125 | + with self.assertRaises(ValueError): |
| 126 | + gradient_accumulation(train_step, [], model) |
| 127 | + |
| 128 | + def test_error_mutated_input_tensors(self): |
| 129 | + """Test that mutating input tensors raises an error.""" |
| 130 | + batch_size = 2 |
| 131 | + hidden_dim = 20 |
| 132 | + input_dim = 10 |
| 133 | + output_dim = 5 |
| 134 | + |
| 135 | + model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) |
| 136 | + |
| 137 | + inputs = torch.randn(batch_size, input_dim).to(self.device) |
| 138 | + targets = torch.randn(batch_size, output_dim).to(self.device) |
| 139 | + counter = torch.tensor(0).to(self.device) |
| 140 | + |
| 141 | + def train_step(input_batch, target_batch, counter): |
| 142 | + output = model(input_batch) |
| 143 | + loss = torch.nn.functional.mse_loss(output, target_batch) |
| 144 | + # In-place mutation of an input tensor. |
| 145 | + counter += 1 |
| 146 | + return loss, counter |
| 147 | + |
| 148 | + with self.assertRaises(AssertionError): |
| 149 | + accumulated_loss, final_counter = gradient_accumulation( |
| 150 | + train_step, |
| 151 | + (inputs, targets), |
| 152 | + model, |
| 153 | + counter |
| 154 | + ) |
| 155 | + |
| 156 | +if __name__ == '__main__': |
| 157 | + test = unittest.main() |
| 158 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
| 159 | + |
0 commit comments