
Commit 4337f9e
Fix tests
Parent: 0bbfd1b

Showing 2 changed files with 18 additions and 11 deletions.

2 files changed

+18
-11
lines changed

test/scan/test_scan.py (+3, -1)

@@ -475,7 +475,9 @@ def unpack(x):
 
     # Find the input that is stored in the context object.
     stored_xs = None
-    for s in storage:
+    # Dedupe the tensors because the autograd context may save the same tensor twice.
+    # Saving a tensor twice won't use extra storage though thanks to ref-counting.
+    for s in set(storage):
       if s.shape == xs.shape:
         assert stored_xs is None
         stored_xs = s
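For context on why `set(storage)` is safe here, a minimal sketch (not part of the commit; the `SaveTwice` op is hypothetical): an autograd context may save the same tensor more than once, the duplicates share one storage thanks to ref-counting, and tensors hash by identity, so `set()` collapses them before the shape-based matching in the test.

import torch


class SaveTwice(torch.autograd.Function):
  """Hypothetical op whose context saves the same input tensor twice."""

  @staticmethod
  def forward(ctx, x):
    ctx.save_for_backward(x, x)  # duplicate save; both entries share x's storage
    return x * 2

  @staticmethod
  def backward(ctx, grad_output):
    a, b = ctx.saved_tensors
    assert a.data_ptr() == b.data_ptr()  # same underlying storage, no extra memory
    return grad_output * 2


x = torch.randn(3, requires_grad=True)
SaveTwice.apply(x).sum().backward()

# Tensors hash by identity, so set() drops the duplicate reference before any
# shape-based matching, which is what the dedupe in the test relies on.
storage = [x, x]
assert len(set(storage)) == 1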

test/spmd/test_xla_sharding.py (+15, -10)

@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(xr.device_type() == 'CPU',
-                   "sharding will be the same for both tensors on single device"
-                  )
+  @unittest.skipIf(
+      xr.device_type() == 'CPU',
+      "sharding will be the same for both tensors on single device")
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,8 +1383,9 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1405,8 +1406,9 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1427,8 +1429,9 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1569,7 +1572,6 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
-<<<<<<< HEAD
   def test_valid_mesh_creation(self):
     mesh_shape = (1, self.n_devices)
     axis_names = ('data', 'model')
@@ -1661,6 +1663,9 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
+  @unittest.skipIf(
+      xr.device_type() == 'CPU',
+      "sharding will be the same for both tensors on single device")
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)
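The decorator changes above are pure yapf-style reformatting plus one new skip guard and the removal of a leftover merge-conflict marker. As a standalone sketch of the same skip pattern (the test class and body here are hypothetical; only the `torch_xla.runtime` call is taken from the diff), a multi-device test is typically guarded like this:

import unittest

import torch_xla.runtime as xr


class ExampleShardingTest(unittest.TestCase):

  @unittest.skipUnless(
      xr.global_runtime_device_count() > 1,
      "Multiple devices required for dataloader sharding test")
  def test_needs_multiple_devices(self):
    # Placeholder body; real tests build a mesh and shard tensors across it.
    self.assertGreater(xr.global_runtime_device_count(), 1)


if __name__ == '__main__':
  unittest.main()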
