@@ -617,9 +617,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1379,9 +1379,8 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1402,9 +1401,8 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1425,9 +1423,8 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1568,9 +1565,9 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)
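
For reference, a minimal standalone sketch of the skip-decorator pattern these hunks reformat. It is not taken from the test file: fake_device_count is a hypothetical placeholder for the xr.global_runtime_device_count() query used in the real tests, and the class name is invented for illustration.

import unittest


def fake_device_count():
  # Hypothetical stand-in for a runtime device query such as
  # torch_xla.runtime.global_runtime_device_count(); returning 1 here
  # causes the decorated test below to be skipped.
  return 1


class SkipPatternExample(unittest.TestCase):

  # The skip condition is evaluated once, when the class body is executed,
  # which is why the diff above keeps runtime-initializing calls out of it.
  @unittest.skipUnless(fake_device_count() > 1,
                       "Multiple devices required for this test")
  def test_needs_multiple_devices(self):
    self.assertTrue(True)


if __name__ == "__main__":
  unittest.main()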