
Commit 0db775d

rpsilva-aws authored and root committed
Refine the gradient accumulation API
1 parent d16766f commit 0db775d

7 files changed: +345 -260 lines changed

test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ function run_xla_op_tests1 {
   run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
   run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
   run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_test "$CDIR/test_gradient_accumulation.py"
 }

 function run_xla_op_tests2 {

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -243,6 +243,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$CDIR/spmd/test_mp_input_sharding.py"
   run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
+  run_test "$CDIR/test_gradient_accumulation.py"
   run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"

test/spmd/test_train_spmd_linear_model.py

Lines changed: 12 additions & 6 deletions
@@ -5,6 +5,7 @@
 import unittest

 import torch
+from torch_xla import runtime as xr

 import test_xla_sharding_base

@@ -19,6 +20,9 @@
 # the gradient checkpointing A/B test run for it.
 SKIP_GRADIENT_CHECKPOINTING: bool = False

+skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA',
+                            'https://github.com/pytorch/xla/issues/9128')
+

 @contextmanager
 def extended_argv(args):
@@ -33,7 +37,7 @@ def extended_argv(args):
 class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):

   def test_basic(self):
-    print('Training loop with baseline')
+    print('Training loop with baseline', flush=True)
     with extended_argv([]):
       baseline_losses, baseline_result = train_and_evaluate()
     # Verify that the model losses are not zero.
@@ -42,7 +46,7 @@ def test_basic(self):
     assert not torch.any(baseline_result == 0)

     if not SKIP_GRADIENT_CHECKPOINTING:
-      print('Training loop with gradient checkpointing')
+      print('Training loop with gradient checkpointing', flush=True)
       with extended_argv(['--use_gradient_checkpointing']):
         checkpointing_losses, checkpointing_result = train_and_evaluate()
       # Verify that the runs match with and without checkpointing.
@@ -62,11 +66,11 @@ def test_gradient_accumulation_matches(self):
     """

     COMMON_GRAD_ACC_ARGS = ["--gradient_accumulation_steps", "8"]
-    print('Training loop with traditional gradient accumulation')
+    print('Training loop with traditional gradient accumulation', flush=True)
     with extended_argv(COMMON_GRAD_ACC_ARGS):
       baseline_grad_acc_losses = train_and_evaluate_grad_acc()

-    print('Training loop with XLA\'s `While` gradient accumulation')
+    print('Training loop with XLA\'s `While` gradient accumulation', flush=True)
     with extended_argv(COMMON_GRAD_ACC_ARGS +
                        ["--use_gradient_accumulation_loop"]):
       loop_grad_acc_losses = train_and_evaluate_grad_acc()
@@ -79,8 +83,10 @@ def test_gradient_accumulation_matches(self):
                              loop_grad_acc_losses))

     if not SKIP_GRADIENT_CHECKPOINTING:
-      print('Training loop with XLA\'s `While` gradient accumulation and '
-            'gradient checkpointing.')
+      print(
+          'Training loop with XLA\'s `While` gradient accumulation and '
+          'gradient checkpointing.',
+          flush=True)
       with extended_argv(
           COMMON_GRAD_ACC_ARGS +
           ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):

test/test_gradient_accumulation.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import sys
+import unittest
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.test.test_utils as test_utils
+from torch_xla.experimental.gradient_accumulation import gradient_accumulation
+
+from test_utils import XlaTestCase  # type:ignore
+
+
+class SimpleModel(torch.nn.Module):
+
+  def __init__(self, input_dim=10, hidden_dim=20, output_dim=5):
+    super(SimpleModel, self).__init__()
+    self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
+    self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
+
+  def forward(self, x):
+    x = torch.relu(self.fc1(x))
+    return self.fc2(x)
+
+
+class GradAccumulationTest(XlaTestCase):
+
+  def setUp(self):
+    self.device = xm.xla_device()
+    torch.manual_seed(123)
+
+  def test_basic(self):
+    """Compare results with and without the XLA loop"""
+    batch_size = 8
+    hidden_dim = 20
+    input_dim = 10
+    output_dim = 5
+
+    inputs = torch.randn(batch_size, input_dim).to(self.device)
+    targets = torch.randn(batch_size, output_dim).to(self.device)
+
+    def train_step_fw(input_batch, target_batch, carried_tensor):
+      output = model_ga(input_batch)
+      loss = torch.nn.functional.mse_loss(output, target_batch)
+      new_carried_tensor = carried_tensor + 5
+      return loss, new_carried_tensor
+
+    # Gradient accumulation with XLA loop
+    torch.manual_seed(43)
+    model_ga = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
+    carried_tensor_ga = torch.tensor([5, 5]).to(self.device)
+
+    accumulated_loss_ga, accum_carried_tensor_ga = gradient_accumulation(
+        train_step_fw, (inputs, targets), model_ga, carried_tensor_ga)
+
+    torch_xla.sync()
+
+    # Traditional accumulation
+    torch.manual_seed(43)
+    model_manual = SimpleModel(input_dim, hidden_dim,
+                               output_dim).to(self.device)
+    carried_tensor_manual = torch.tensor([5, 5]).to(self.device)
+
+    accumulated_loss_manual = torch.tensor(0.0).to(self.device)
+    for i in range(batch_size):
+      loss, carried_tensor_manual = train_step_fw(inputs[i:i + 1],
+                                                  targets[i:i + 1],
+                                                  carried_tensor_manual)
+      loss = loss / batch_size
+      loss.backward()
+      accumulated_loss_manual += loss.detach()
+
+    torch_xla.sync()
+
+    # Compare losses, carried tensors and resulting gradients
+    super().compareResults([accumulated_loss_ga], [accumulated_loss_manual])
+    super().compareResults([accum_carried_tensor_ga], [carried_tensor_manual])
+    super().compareResults(model_ga.parameters(), model_manual.parameters())
+
+  def test_with_carried_tensors(self):
+    """Test gradient accumulation with carried tensors, including with RNG"""
+    batch_size = 2
+    hidden_dim = 20
+    input_dim = 10
+    output_dim = 5
+
+    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
+
+    inputs = torch.randn(batch_size, input_dim).to(self.device)
+    targets = torch.randn(batch_size, output_dim).to(self.device)
+
+    # Carried tensors
+    counter = torch.tensor(0).to(self.device)
+    tensor0 = torch.tensor(0.0).to(self.device)
+    tensor0_baseline = tensor0.clone()
+
+    # Define train step function that updates the carried tensor. In the case of
+    # RNG, we negate the previous value, in order to validate that we get unique
+    # RNG seeds for each iteration.
+    def train_step_fw(input_batch, target_batch, counter, tensor0):
+      output = model(input_batch)
+      loss = torch.nn.functional.mse_loss(output, target_batch)
+      # Update counter
+      new_counter = counter + 1
+      new_tensor0 = torch.rand_like(tensor0, device=self.device) - tensor0
+      return loss, new_counter, new_tensor0
+
+    # Run gradient accumulation
+    accumulated_loss, final_counter, final_tensor0 = gradient_accumulation(
+        train_step_fw, (inputs, targets), model, counter, tensor0)
+
+    torch_xla.sync()
+
+    self.assertEqual(final_counter.item(), batch_size)
+    # Ensure that the result is not 0, showcasing that the RNG is unique
+    # per iteration.
+    self.assertNotEqual(final_tensor0.item(), 0.0)
+
+  def test_error_empty_iterable_tensors(self):
+    """Test that empty iterable_tensors raises an error."""
+    model = SimpleModel().to(self.device)
+
+    def train_step_fw():
+      pass
+
+    with self.assertRaises(ValueError):
+      gradient_accumulation(train_step_fw, [], model)
+
+  def test_error_mutated_input_tensors(self):
+    """Test that mutating input tensors raises an error."""
+    batch_size = 2
+    hidden_dim = 20
+    input_dim = 10
+    output_dim = 5
+
+    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
+
+    inputs = torch.randn(batch_size, input_dim).to(self.device)
+    targets = torch.randn(batch_size, output_dim).to(self.device)
+    counter = torch.tensor(0).to(self.device)
+
+    def train_step_fw(input_batch, target_batch, counter):
+      output = model(input_batch)
+      loss = torch.nn.functional.mse_loss(output, target_batch)
+      # In-place mutation of an input tensor.
+      counter += 1
+      return loss, counter
+
+    with self.assertRaises(AssertionError):
+      accumulated_loss, final_counter = gradient_accumulation(
+          train_step_fw, (inputs, targets), model, counter)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
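
Taken together, the new tests above exercise a call pattern along the following lines (a minimal usage sketch inferred from this commit's tests; the linear model, the shapes, and `step_fn` are illustrative stand-ins, not part of the diff):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)  # illustrative stand-in model
inputs = torch.randn(8, 10).to(device)     # leading dim = number of accumulation steps
targets = torch.randn(8, 5).to(device)
counter = torch.tensor(0).to(device)       # optional carried tensor

def step_fn(input_batch, target_batch, counter):
  # Return the per-step loss; gradients for model's parameters are
  # accumulated across the steps.
  loss = torch.nn.functional.mse_loss(model(input_batch), target_batch)
  # Carried tensors are returned, never mutated in place
  # (see test_error_mutated_input_tensors above).
  return loss, counter + 1

accumulated_loss, final_counter = gradient_accumulation(
    step_fn, (inputs, targets), model, counter)
torch_xla.sync()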

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
 python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
 python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
 python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
+python3 "$TEST_CDIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v
 python3 "$TEST_CDIR/test_autocast.py"
 python3 "$TEST_CDIR/test_fp8.py"

test/utils/train_spmd_linear_model_grad_acc.py

Lines changed: 1 addition & 2 deletions
@@ -125,8 +125,7 @@ def train_step(input_id, label):

   def train_loop_fn(data, target, running_loss):
     if FLAGS.use_gradient_accumulation_loop:
-      running_loss, = gradient_accumulation(train_step, (data, target), model,
-                                            None)
+      running_loss = gradient_accumulation(train_step, (data, target), model)
     else:
       for i in range(FLAGS.gradient_accumulation_steps):
         loss = train_step(data[i], target[i])
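
The hunk above shows the user-facing refinement in miniature: when no carried tensors are passed, the call drops the trailing `None` placeholder and the accumulated loss is assigned directly rather than unpacked from a single-element tuple.

# Before: placeholder argument plus single-element tuple unpacking.
running_loss, = gradient_accumulation(train_step, (data, target), model, None)

# After: carried tensors are optional and the loss is returned directly.
running_loss = gradient_accumulation(train_step, (data, target), model)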
