Commit 477ca24

Add parameter broadcasting to PJRT examples. (#3836)
* Add parameter broadcasting to PJRT examples.
* Naming fix in tests
1 parent 7b6747d commit 477ca24
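In short, the examples no longer build a model in the launcher and ship its state_dict to every spawned process; instead each process constructs its own model, moves it to its XLA device, and calls the new pjrt.broadcast_master_param to copy the master's parameters and buffers to all replicas. A minimal sketch of the resulting pattern, assuming a PJRT-enabled environment (the train_fn name and the Linear layer are illustrative, not part of the commit):

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def train_fn():
  # Each process may initialize its model differently; the broadcast below
  # overwrites every replica's parameters with the master's.
  device = xm.xla_device()
  model = nn.Linear(5, 5).to(device)
  pjrt.broadcast_master_param(model)
  # ... training loop ...


if __name__ == '__main__':
  # No state_dict argument is threaded through run_multiprocess anymore.
  pjrt.run_multiprocess(train_fn)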

File tree

4 files changed, +61 -10 lines changed

+36
@@ -0,0 +1,36 @@
+from numpy.testing import assert_array_equal, assert_raises
+import torch
+import torch.nn as nn
+import torch_xla.core.xla_model as xm
+from torch_xla.experimental import pjrt
+from absl.testing import absltest, parameterized
+
+
+def broadcast(sync):
+  torch.manual_seed(xm.get_ordinal())
+  device = xm.xla_device()
+  model = nn.Linear(5, 5).to(device)
+  if sync:
+    pjrt.broadcast_master_param(model)
+  return next(model.parameters()).detach().cpu().numpy()
+
+
+class TestBroadcastParametersPjrt(parameterized.TestCase):
+
+  @parameterized.named_parameters(('synchronized_parameters', True),
+                                  ('unsynchronized_parameters', False))
+  def test_broadcast_parameter_sync(self, sync):
+    torch.set_default_tensor_type('torch.FloatTensor')
+    results = pjrt.run_multiprocess(broadcast, sync)
+    master_params = results[0][0]
+    for process_key in results:
+      worker_params = results[process_key][0]
+      if sync:
+        assert_array_equal(master_params, worker_params)
+      elif process_key != 0:
+        assert_raises(AssertionError, assert_array_equal, master_params,
+                      worker_params)
+
+
+if __name__ == '__main__':
+  absltest.main()

test/pjrt/test_train_pjrt_imagenet.py

+3-5
@@ -114,7 +114,7 @@ def _train_update(device, step, loss, tracker, epoch, writer):
       summary_writer=writer)
 
 
-def train_imagenet(state_dict):
+def train_imagenet():
   print('==> Preparing data..')
   img_dim = get_model_property('img_dim')
   if FLAGS.fake_data:
@@ -182,8 +182,8 @@ def train_imagenet(state_dict):
 
   device = xm.xla_device()
   model = get_model_property('model_fn')()
-  model.load_state_dict(state_dict)
   model = model.to(device)
+  pjrt.broadcast_master_param(model)
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(FLAGS.logdir)
@@ -262,10 +262,8 @@ def test_loop_fn(loader, epoch):
 
 if __name__ == '__main__':
   torch.set_default_tensor_type('torch.FloatTensor')
-  torch.manual_seed(42)
-  model = get_model_property('model_fn')()
 
-  results = pjrt.run_multiprocess(train_imagenet, model.state_dict())
+  results = pjrt.run_multiprocess(train_imagenet)
   print('Replica max_accuracy:', pprint.pformat(results))
   accuracy = np.mean([
       np.mean(list(thread_results.values()))

test/pjrt/test_train_pjrt_mnist.py

+3-5
@@ -59,7 +59,7 @@ def _train_update(device, x, loss, tracker, writer):
       summary_writer=writer)
 
 
-def train_mnist(flags, state_dict):
+def train_mnist(flags):
   if flags.fake_data:
     train_loader = xu.SampleGenerator(
         data=(torch.zeros(flags.batch_size, 1, 28,
@@ -112,8 +112,8 @@ def train_mnist(flags, state_dict):
 
   device = xm.xla_device()
   model = MNIST()
-  model.load_state_dict(state_dict)
   model = model.to(device)
+  pjrt.broadcast_master_param(model)
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(flags.logdir)
@@ -177,10 +177,8 @@ def test_loop_fn(loader):
 
 if __name__ == '__main__':
   torch.set_default_tensor_type('torch.FloatTensor')
-  torch.manual_seed(1)
-  model = MNIST()
 
-  results = pjrt.run_multiprocess(train_mnist, FLAGS, model.state_dict())
+  results = pjrt.run_multiprocess(train_mnist, FLAGS)
   print('Replica max_accuracy:', pprint.pformat(results))
   accuracy = np.mean([
       np.mean(list(thread_results.values()))
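Both training examples make the same change: run_multiprocess forwards its extra positional arguments to the target function in every spawned process, and previously that mechanism carried a state_dict built in the launcher (plus a fixed manual_seed) so replicas would start from identical weights. With broadcast_master_param called inside each worker, only the flags still need to be forwarded. A small sketch of the argument forwarding, assuming run_multiprocess(fn, *args) as shown in the pjrt.py diff below; the train_fn name and the flags value are illustrative:

from torch_xla.experimental import pjrt


def train_fn(flags):
  # Every spawned process receives the same positional arguments,
  # mirroring pjrt.run_multiprocess(train_mnist, FLAGS) above.
  print(flags)


if __name__ == '__main__':
  pjrt.run_multiprocess(train_fn, {'batch_size': 128})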

torch_xla/experimental/pjrt.py

+19
@@ -2,9 +2,11 @@
 import functools
 import os
 import threading
+from itertools import chain
 from typing import Callable, Dict, Optional, TypeVar
 
 import torch
+import torch.nn as nn
 import torch_xla
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
@@ -214,3 +216,20 @@ def run_multiprocess(fn: Callable[..., R], *args,
   }
 
   return results
+
+
+def broadcast_master_param(model: nn.Module) -> None:
+  """
+  Broadcast the model parameters from master process to other processes
+  """
+  parameters_and_buffers = []
+  is_master = xm.is_master_ordinal(local=False)
+  for p in chain(model.parameters(), model.buffers()):
+    # Set all params in non-master devices to zero so that all_reduce is
+    # equivalent to broadcasting parameters from master to other devices.
+    scale = torch.tensor(1 if is_master else 0, dtype=p.data.dtype)
+    scale = scale.to(p.data.device)
+    p.data.mul_(scale)
+    parameters_and_buffers.append(p.data)
+  xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
+  xm.mark_step()
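The implementation relies on a standard trick: if every non-master replica scales its copy of a tensor to zero, a sum all-reduce across replicas leaves exactly the master's values everywhere, so the reduce acts as a broadcast. A CPU-only sketch of the arithmetic (simulating three replicas with plain tensors; torch.stack(...).sum(0) stands in for xm.all_reduce, so no XLA device is needed):

import torch

# Per-replica parameter values before synchronization (replica 0 is the master).
replica_params = [
    torch.tensor([1.0, 2.0]),    # master, ordinal 0
    torch.tensor([5.0, 6.0]),    # worker, ordinal 1
    torch.tensor([9.0, 10.0]),   # worker, ordinal 2
]

# Each replica multiplies its local copy by 1 (master) or 0 (non-master)...
scaled = [p * (1.0 if ordinal == 0 else 0.0)
          for ordinal, p in enumerate(replica_params)]

# ...so summing across replicas reproduces the master's values on every replica.
synced = torch.stack(scaled).sum(dim=0)
assert torch.equal(synced, replica_params[0])

The trailing xm.mark_step() cuts the pending XLA graph (the scaling and the all_reduce) so the synchronized values are materialized before training begins.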
