
Commit cf46ce1

rpsilva-awsroot authored and root committed
Refine the gradient accumulation API
1 parent d16766f commit cf46ce1

3 files changed: +332, -254 lines
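For context, the refined gradient_accumulation API exercised below takes a per-step train function, the batched (iterable) tensors, the model, and any number of optional carried tensors passed as trailing arguments; it returns the accumulated loss followed by the updated carried tensors. The following is a minimal call-site sketch, assuming that signature as inferred from the new test and the updated call site in this commit; the model, data shapes, and optimizer step are illustrative and not part of the change.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The leading dimension of each iterable tensor is the number of
# accumulation steps; each step sees one slice along that dimension.
inputs = torch.randn(8, 10).to(device)
targets = torch.randn(8, 5).to(device)
step_count = torch.tensor(0).to(device)  # an optional carried tensor

def train_step(input_batch, target_batch, step_count):
  loss = torch.nn.functional.mse_loss(model(input_batch), target_batch)
  # Carried tensors are updated functionally, not mutated in place.
  return loss, step_count + 1

loss, step_count = gradient_accumulation(train_step, (inputs, targets), model,
                                         step_count)
optimizer.step()  # assumed: apply the accumulated gradients afterwards
torch_xla.sync()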

test/test_gradient_accumulation.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

from test_utils import XlaTestCase  # type:ignore


class SimpleModel(torch.nn.Module):

  def __init__(self, input_dim=10, hidden_dim=20, output_dim=5):
    super(SimpleModel, self).__init__()
    self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
    self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    x = torch.relu(self.fc1(x))
    return self.fc2(x)


class GradAccumulationTest(XlaTestCase):

  def setUp(self):
    self.device = xm.xla_device()
    torch.manual_seed(123)

  def test_basic(self):
    """Compare results with and without the XLA loop."""
    batch_size = 8
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)

    def train_step_fw(input_batch, target_batch, carried_tensor):
      output = model_ga(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      new_carried_tensor = carried_tensor + 5
      return loss, new_carried_tensor

    # Gradient accumulation with the XLA loop.
    torch.manual_seed(43)
    model_ga = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
    carried_tensor_ga = torch.tensor([5, 5]).to(self.device)

    accumulated_loss_ga, accum_carried_tensor_ga = gradient_accumulation(
        train_step_fw, (inputs, targets), model_ga, carried_tensor_ga)

    torch_xla.sync()

    # Traditional accumulation.
    torch.manual_seed(43)
    model_manual = SimpleModel(input_dim, hidden_dim,
                               output_dim).to(self.device)
    carried_tensor_manual = torch.tensor([5, 5]).to(self.device)

    accumulated_loss_manual = torch.tensor(0.0).to(self.device)
    for i in range(batch_size):
      loss, carried_tensor_manual = train_step_fw(inputs[i:i + 1],
                                                  targets[i:i + 1],
                                                  carried_tensor_manual)
      loss = loss / batch_size
      loss.backward()
      accumulated_loss_manual += loss.detach()

    torch_xla.sync()

    # Compare losses, carried tensors and resulting gradients.
    super().compareResults([accumulated_loss_ga], [accumulated_loss_manual])
    super().compareResults([accum_carried_tensor_ga], [carried_tensor_manual])
    super().compareResults(model_ga.parameters(), model_manual.parameters())

  def test_with_carried_tensors(self):
    """Test gradient accumulation with carried tensors, including with RNG."""
    batch_size = 2
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)

    # Carried tensors.
    counter = torch.tensor(0).to(self.device)
    tensor0 = torch.tensor(0.0).to(self.device)
    tensor0_baseline = tensor0.clone()

    # Define a train step function that updates the carried tensors. For the
    # RNG tensor, we subtract the previous value, in order to validate that we
    # get a unique RNG seed for each iteration.
    def train_step(input_batch, target_batch, counter, tensor0):
      output = model(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      # Update the counter.
      new_counter = counter + 1
      new_tensor0 = torch.rand_like(tensor0, device=self.device) - tensor0
      return loss, new_counter, new_tensor0

    # Run gradient accumulation.
    accumulated_loss, final_counter, final_tensor0 = gradient_accumulation(
        train_step, (inputs, targets), model, counter, tensor0)

    torch_xla.sync()

    self.assertEqual(final_counter.item(), batch_size)
    # Ensure that the result is not 0, showcasing that the RNG is unique
    # per iteration.
    self.assertNotEqual(final_tensor0.item(), 0.0)

  def test_error_empty_iterable_tensors(self):
    """Test that empty iterable_tensors raises an error."""
    model = SimpleModel().to(self.device)

    def train_step():
      pass

    with self.assertRaises(ValueError):
      gradient_accumulation(train_step, [], model)

  def test_error_mutated_input_tensors(self):
    """Test that mutating input tensors raises an error."""
    batch_size = 2
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)
    counter = torch.tensor(0).to(self.device)

    def train_step(input_batch, target_batch, counter):
      output = model(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      # In-place mutation of an input tensor.
      counter += 1
      return loss, counter

    with self.assertRaises(AssertionError):
      accumulated_loss, final_counter = gradient_accumulation(
          train_step, (inputs, targets), model, counter)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
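The two error tests above pin down part of the contract: the iterable tensors must be non-empty, and input or carried tensors must not be mutated in place inside the step function. Below is a small sketch of the functional update pattern the loop expects, assuming `model` is a SimpleModel instance as in the tests above.

def train_step(input_batch, target_batch, counter):
  loss = torch.nn.functional.mse_loss(model(input_batch), target_batch)
  # Return a fresh tensor instead of `counter += 1`; the in-place form trips
  # the assertion exercised by test_error_mutated_input_tensors.
  return loss, counter + 1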

test/utils/train_spmd_linear_model_grad_acc.py

Lines changed: 1 addition & 2 deletions
@@ -125,8 +125,7 @@ def train_step(input_id, label):
 
   def train_loop_fn(data, target, running_loss):
     if FLAGS.use_gradient_accumulation_loop:
-      running_loss, = gradient_accumulation(train_step, (data, target), model,
-                                            None)
+      running_loss = gradient_accumulation(train_step, (data, target), model)
     else:
       for i in range(FLAGS.gradient_accumulation_steps):
         loss = train_step(data[i], target[i])
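Read alongside the new test, this call-site change is the concrete shape of the refinement named in the commit title: carried tensors appear to be passed as trailing positional arguments instead of an explicit None placeholder, and when none are supplied the accumulated loss is returned directly rather than as a 1-tuple. A hedged before/after view of the same call:

# Before: placeholder None for carried tensors; result unpacked from a 1-tuple.
running_loss, = gradient_accumulation(train_step, (data, target), model, None)
# After: no carried tensors, so the accumulated loss is returned directly.
running_loss = gradient_accumulation(train_step, (data, target), model)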
