Skip to content

Commit 90df56a

Browse files
authored
Do not generate default repetition ids if use_repetition_ids=False (#5419)
Fixes #5418
1 parent 872e008 commit 90df56a

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

cirq-core/cirq/circuits/circuit_operation.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ def default_repetition_ids(repetitions: IntParam) -> Optional[List[str]]:
5555
return None
5656

5757

58-
def _full_join_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
58+
def _full_join_string_lists(
59+
list1: Optional[List[str]], list2: Optional[List[str]]
60+
) -> Optional[List[str]]:
5961
if list1 is None and list2 is None:
6062
return None # coverage: ignore
6163
if list1 is None:
6264
return list2 # coverage: ignore
6365
if list2 is None:
6466
return list1
65-
return [
66-
f'{REPETITION_ID_SEPARATOR.join([first, second])}' for first in list1 for second in list2
67-
]
67+
return [f'{first}{REPETITION_ID_SEPARATOR}{second}' for first in list1 for second in list2]
6868

6969

7070
@dataclasses.dataclass(frozen=True)
@@ -224,7 +224,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
224224
return tuple(self.qubit_map.get(q, q) for q in ordered_qubits)
225225

226226
def _default_repetition_ids(self) -> Optional[List[str]]:
227-
return default_repetition_ids(self.repetitions)
227+
return default_repetition_ids(self.repetitions) if self.use_repetition_ids else None
228228

229229
def _qid_shape_(self) -> Tuple[int, ...]:
230230
return tuple(q.dimension for q in self.qubits)
@@ -524,7 +524,8 @@ def repeat(
524524
expected_repetition_id_length = abs(repetitions)
525525

526526
if repetition_ids is None:
527-
repetition_ids = default_repetition_ids(expected_repetition_id_length)
527+
if self.use_repetition_ids:
528+
repetition_ids = default_repetition_ids(expected_repetition_id_length)
528529
elif len(repetition_ids) != expected_repetition_id_length:
529530
raise ValueError(
530531
f'Expected repetition_ids={repetition_ids} length to be '

cirq-core/cirq/circuits/circuit_operation_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import unittest.mock as mock
1415
from typing import Optional
1516

1617
import numpy as np
1718
import pytest
1819
import sympy
1920

2021
import cirq
22+
import cirq.circuits.circuit_operation as circuit_operation
2123
from cirq.circuits.circuit_operation import _full_join_string_lists
2224

2325
ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
@@ -346,6 +348,27 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps):
346348
assert np.allclose(result.state_vector(), [1, 0])
347349

348350

351+
def test_no_repetition_ids():
352+
def default_repetition_ids(self):
353+
assert False, "Should not call default_repetition_ids"
354+
355+
with mock.patch.object(circuit_operation, 'default_repetition_ids', new=default_repetition_ids):
356+
q = cirq.LineQubit(0)
357+
op = cirq.CircuitOperation(
358+
cirq.Circuit(cirq.X(q), cirq.measure(q)).freeze(),
359+
repetitions=1_000_000,
360+
use_repetition_ids=False,
361+
)
362+
assert op.repetitions == 1_000_000
363+
assert op.repetition_ids is None
364+
_ = repr(op)
365+
_ = str(op)
366+
367+
op2 = op.repeat(10)
368+
assert op2.repetitions == 10_000_000
369+
assert op2.repetition_ids is None
370+
371+
349372
def test_parameterized_repeat():
350373
q = cirq.LineQubit(0)
351374
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q))) ** sympy.Symbol('a')

0 commit comments

Comments
 (0)