Commit 9dcfe1c

yapf
1 parent 4337f9e commit 9dcfe1c
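
The commit message is just "yapf", and every hunk is a formatting-only reflow of unittest.skipIf / unittest.skipUnless decorator calls, so this is presumably the output of re-running yapf over the touched test file. Below is a minimal sketch of such a run via yapf's Python API; the actual invocation and any project style settings are assumptions, not recorded in the commit.

# Sketch only: reproduce a formatting-only pass like this commit by re-running
# yapf on the touched file. Roughly equivalent to running
# "yapf -i test/spmd/test_xla_sharding.py" from the command line; the style
# yapf applies (e.g. a repo-level .style.yapf, if present) is an assumption
# here, not something the diff records.
from yapf.yapflib.yapf_api import FormatFile

# in_place=True rewrites the file with yapf's preferred line wrapping, which
# is what produces the decorator reflow shown in the hunks below.
FormatFile('test/spmd/test_xla_sharding.py', in_place=True)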

File tree

1 file changed (+12, -15 lines)

test/spmd/test_xla_sharding.py (+12, -15)

@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,9 +1383,8 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1406,9 +1405,8 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1429,9 +1427,8 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1663,9 +1660,9 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)
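
For readers unfamiliar with the skip decorators being reflowed above, here is a standalone sketch (not part of the commit) of how unittest.skipUnless gates a test on a device-count condition; device_count() below is a hypothetical stand-in for xr.global_runtime_device_count() so the example runs without torch_xla:

import unittest


def device_count():
  # Hypothetical stand-in for xr.global_runtime_device_count(), which in
  # torch_xla reports the total number of runtime devices.
  return 1


class SkipExample(unittest.TestCase):

  # The condition is evaluated once at class-definition time, exactly like
  # the decorators in the diff above: on a single device the test is
  # reported as skipped rather than run.
  @unittest.skipUnless(device_count() > 1,
                       "Multiple devices required for dataloader sharding test")
  def test_needs_multiple_devices(self):
    self.assertGreater(device_count(), 1)


if __name__ == '__main__':
  unittest.main()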
