
Commit 5f84ac1

yapf
1 parent d9bbbe4 commit 5f84ac1

1 file changed: +12 -15 lines changed

test/spmd/test_xla_sharding.py

+12 -15
@@ -617,9 +617,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1379,9 +1379,8 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1402,9 +1401,8 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1425,9 +1423,8 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1568,9 +1565,9 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)
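
For reference, reformatting like the above can be reproduced by running yapf over the file. Below is a minimal sketch using yapf's FormatCode API; the 'google' base style is an assumption (the repository may pin its own .style.yapf), so the exact output may differ slightly from the diff.

# Minimal sketch: reformat a decorator the way this commit does.
# Assumptions: yapf is installed, and the 'google' base style approximates
# the repo's pinned .style.yapf, so output may not match the diff exactly.
from yapf.yapflib.yapf_api import FormatCode

snippet = '''\
@unittest.skipUnless(
    xr.global_runtime_device_count() > 1,
    "Multiple devices required for dataloader sharding test")
def test_data_loader_with_sharding(self):
  pass
'''

formatted, changed = FormatCode(snippet, style_config='google')
print(changed)    # True if yapf rewrote the snippet
print(formatted)  # decorator arguments may be joined onto fewer lines, as in the diff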
