
Commit 5c15061

Add test
1 parent 4637fb0 commit 5c15061

2 files changed: +35 −10 lines changed


test/spmd/test_xla_sharding.py

Lines changed: 33 additions & 9 deletions
@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(xr.device_type() == 'CPU',
-                   "sharding will be the same for both tensors on single device"
-                  )
+  @unittest.skipIf(
+      xr.device_type() == 'CPU',
+      "sharding will be the same for both tensors on single device")
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,8 +1383,9 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1405,8 +1406,9 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1427,8 +1429,9 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1569,6 +1572,7 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
+<<<<<<< HEAD
   def test_valid_mesh_creation(self):
     mesh_shape = (1, self.n_devices)
     axis_names = ('data', 'model')
@@ -1660,6 +1664,26 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
+  def test_shard_as(self):
+    mesh = self._get_mesh((self.n_devices,))
+    partition_spec = (0,)
+    x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
+                     dtype=torch.float,
+                     device=xm.xla_device())
+    x = xs.mark_sharding_with_gradients(x, mesh, partition_spec)
+    y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
+                     dtype=torch.float,
+                     device=xm.xla_device())
+
+    x, y = xs.shard_as(x, y)
+    torch_xla.sync()
+
+    sharding_spec = '{devices=[%d]' % self.n_devices
+    x_sharding = torch_xla._XLAC._get_xla_sharding_spec(x)
+    y_sharding = torch_xla._XLAC._get_xla_sharding_spec(y)
+    self.assertIn(sharding_spec, x_sharding)
+    self.assertEqual(x_sharding, y_sharding)
+
 
 if __name__ == '__main__':
   test = unittest.main()
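
For readers skimming the diff, the flow that the new test_shard_as exercises reduces to the standalone sketch below. It only rearranges calls that already appear in the diff (get_1d_mesh, mark_sharding_with_gradients, shard_as, torch_xla.sync, _get_xla_sharding_spec); the xr.use_spmd() call is an assumption about test setup not shown in this commit, and the snippet presumes an SPMD-capable runtime with more than one device.

# Minimal sketch of the pattern test_shard_as covers; assumes an SPMD-enabled
# runtime with more than one addressable device.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # assumption: SPMD mode is enabled before any XLA tensors exist

mesh = xs.get_1d_mesh("data")  # 1-D mesh over all devices, as in the dataloader tests
partition_spec = (0,)          # shard dimension 0 across the mesh

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
                 dtype=torch.float,
                 device=xm.xla_device())
x = xs.mark_sharding_with_gradients(x, mesh, partition_spec)
y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
                 dtype=torch.float,
                 device=xm.xla_device())

# shard_as ties the two tensors together; after the sync, the test asserts that
# both carry the same '{devices=[N]...' sharding spec.
x, y = xs.shard_as(x, y)
torch_xla.sync()

print(torch_xla._XLAC._get_xla_sharding_spec(x))
print(torch_xla._XLAC._get_xla_sharding_spec(y))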

torch_xla/distributed/spmd/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
     mark_sharding, mark_sharding_with_gradients, clear_sharding, get_1d_mesh,
     wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh,
     get_global_mesh, _mark_manual_sharding, enable_manual_sharding,
-    disable_manual_sharding, apply_backward_optimization_barrier)
+    disable_manual_sharding, apply_backward_optimization_barrier, shard_as)
 from .api import xla_distribute_tensor, xla_distribute_module, auto_policy
 
 __all__ = [
@@ -19,6 +19,7 @@
     "MarkShardingFunction"
     "mark_sharding",
     "mark_sharding_with_gradients",
+    "shard_as",
     "clear_sharding",
     "get_1d_mesh",
     "wrap_if_sharded",

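The re-export above is what makes the xs.shard_as spelling in the sketch work: after this commit the function is importable from the package root. A minimal check, with the asserts being illustrative rather than part of the commit:

# shard_as is now part of the spmd package's public surface.
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import shard_as

assert "shard_as" in xs.__all__
assert xs.shard_as is shard_as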