Skip to content

Commit 563d13f

Browse files
Test substring for quimb objects (quantumlib#5529)
Fixes quantumlib#5524. Based on quantumlib#5525 discussion, we don't need to pin quimb (yet!), but testing exact strings for quimb objects is asking for trouble. This reduces the tests to substring checks, which should be less vulnerable to minor quimb movements (e.g. this passes with both quimb v1.3.0 and v1.4.0).
1 parent 0ac0953 commit 563d13f

File tree

4 files changed

+56
-48
lines changed

4 files changed

+56
-48
lines changed

cirq/contrib/quimb/mps_simulator_test.py

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,12 @@ def test_trial_result_str():
266266
prng=value.parse_random_state(0),
267267
simulation_options=ccq.mps_simulator.MPSOptions(),
268268
)
269-
assert (
270-
str(
271-
ccq.mps_simulator.MPSTrialResult(
272-
params=cirq.ParamResolver({}),
273-
measurements={'m': np.array([[1]])},
274-
final_simulator_state=final_simulator_state,
275-
)
276-
)
277-
== """measurements: m=1
278-
output state: TensorNetwork([
279-
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
280-
])"""
269+
result = ccq.mps_simulator.MPSTrialResult(
270+
params=cirq.ParamResolver({}),
271+
measurements={'m': np.array([[1]])},
272+
final_simulator_state=final_simulator_state,
281273
)
274+
assert 'output state: TensorNetwork' in str(result)
282275

283276

284277
def test_trial_result_repr_pretty():
@@ -293,40 +286,22 @@ def test_trial_result_repr_pretty():
293286
measurements={'m': np.array([[1]])},
294287
final_simulator_state=final_simulator_state,
295288
)
296-
cirq.testing.assert_repr_pretty(
297-
result,
298-
"""measurements: m=1
299-
output state: TensorNetwork([
300-
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
301-
])""",
302-
)
289+
cirq.testing.assert_repr_pretty_contains(result, 'output state: TensorNetwork')
303290
cirq.testing.assert_repr_pretty(result, "cirq.MPSTrialResult(...)", cycle=True)
304291

305292

306293
def test_empty_step_result():
307294
q0 = cirq.LineQubit(0)
308295
sim = ccq.mps_simulator.MPSSimulator()
309296
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
310-
assert (
311-
str(step_result)
312-
== """q(0)=0
313-
TensorNetwork([
314-
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
315-
])"""
316-
)
297+
assert 'TensorNetwork' in str(step_result)
317298

318299

319300
def test_step_result_repr_pretty():
320301
q0 = cirq.LineQubit(0)
321302
sim = ccq.mps_simulator.MPSSimulator()
322303
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
323-
cirq.testing.assert_repr_pretty(
324-
step_result,
325-
"""q(0)=0
326-
TensorNetwork([
327-
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
328-
])""",
329-
)
304+
cirq.testing.assert_repr_pretty_contains(step_result, 'TensorNetwork')
330305
cirq.testing.assert_repr_pretty(step_result, "cirq.MPSSimulatorStepResult(...)", cycle=True)
331306

332307

@@ -391,13 +366,8 @@ def test_simulate_moment_steps_sample():
391366
step._simulator_state().to_numpy(),
392367
np.asarray([1.0 / math.sqrt(2), 0.0, 1.0 / math.sqrt(2), 0.0]),
393368
)
394-
assert (
395-
str(step)
396-
== """TensorNetwork([
397-
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
398-
Tensor(shape=(2,), inds=('i_1',), tags=oset([])),
399-
])"""
400-
)
369+
# There are two "Tensor()" copies in the string.
370+
assert len(str(step).split('Tensor(')) == 3
401371
samples = step.sample([q0, q1], repetitions=10)
402372
for sample in samples:
403373
assert np.array_equal(sample, [True, False]) or np.array_equal(
@@ -412,13 +382,8 @@ def test_simulate_moment_steps_sample():
412382
step._simulator_state().to_numpy(),
413383
np.asarray([1.0 / math.sqrt(2), 0.0, 0.0, 1.0 / math.sqrt(2)]),
414384
)
415-
assert (
416-
str(step)
417-
== """TensorNetwork([
418-
Tensor(shape=(2, 2), inds=('i_0', 'mu_0_1'), tags=oset([])),
419-
Tensor(shape=(2, 2), inds=('mu_0_1', 'i_1'), tags=oset([])),
420-
])"""
421-
)
385+
# There are two "Tensor()" copies in the string.
386+
assert len(str(step).split('Tensor(')) == 3
422387
samples = step.sample([q0, q1], repetitions=10)
423388
for sample in samples:
424389
assert np.array_equal(sample, [True, True]) or np.array_equal(

cirq/testing/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@
9292
random_two_qubit_circuit_with_czs,
9393
)
9494

95-
from cirq.testing.repr_pretty_tester import assert_repr_pretty, FakePrinter
95+
from cirq.testing.repr_pretty_tester import (
96+
assert_repr_pretty,
97+
assert_repr_pretty_contains,
98+
FakePrinter,
99+
)
96100

97101
from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit

cirq/testing/repr_pretty_tester.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,22 @@ def assert_repr_pretty(val: Any, text: str, cycle: bool = False):
5353
p = FakePrinter()
5454
val._repr_pretty_(p, cycle=cycle)
5555
assert p.text_pretty == text, f"{p.text_pretty} != {text}"
56+
57+
58+
def assert_repr_pretty_contains(val: Any, substr: str, cycle: bool = False):
59+
"""Assert that the given object has a `_repr_pretty_` output that contains the given text.
60+
61+
Args:
62+
val: The object to test.
63+
substr: The string that `_repr_pretty_` is expected to contain.
64+
cycle: The value of `cycle` passed to `_repr_pretty_`. `cycle` represents whether
65+
the call is made with a potential cycle. Typically one should handle the
66+
`cycle` equals `True` case by returning text that does not recursively call
67+
the `_repr_pretty_` to break this cycle.
68+
69+
Raises:
70+
AssertionError: If `_repr_pretty_` does not pretty print the given text.
71+
"""
72+
p = FakePrinter()
73+
val._repr_pretty_(p, cycle=cycle)
74+
assert substr in p.text_pretty, f"{substr} not in {p.text_pretty}"

cirq/testing/repr_pertty_tester_test.py renamed to cirq/testing/repr_pretty_tester_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,23 @@ def _repr_pretty_(self, p, cycle):
4242

4343
cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "I'm so pretty I am")
4444
cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "TestClass", cycle=True)
45+
46+
47+
def test_assert_repr_pretty_contains():
48+
class TestClass:
49+
def _repr_pretty_(self, p, cycle):
50+
p.text("TestClass" if cycle else "I'm so pretty")
51+
52+
cirq.testing.assert_repr_pretty_contains(TestClass(), "pretty")
53+
cirq.testing.assert_repr_pretty_contains(TestClass(), "Test", cycle=True)
54+
55+
class TestClassMultipleTexts:
56+
def _repr_pretty_(self, p, cycle):
57+
if cycle:
58+
p.text("TestClass")
59+
else:
60+
p.text("I'm so pretty")
61+
p.text(" I am")
62+
63+
cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "I am")
64+
cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "Class", cycle=True)

0 commit comments

Comments
 (0)