Skip to content

Commit 7892143

Browse files
authored
Print multi-qubit circuit with asymmetric depolarizing noise correctly (#5931)
Closes #5927. Fixes based on discussion in the issue.
1 parent af1267d commit 7892143

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

cirq-core/cirq/ops/common_channels.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ def __repr__(self) -> str:
139139
def __str__(self) -> str:
140140
return 'asymmetric_depolarize(' + f"error_probabilities={self._error_probabilities})"
141141

142-
def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> str:
142+
def _circuit_diagram_info_(
143+
self, args: 'protocols.CircuitDiagramInfoArgs'
144+
) -> Union[str, Iterable[str]]:
143145
if self._num_qubits == 1:
144146
if args.precision is not None:
145147
return (
@@ -154,7 +156,9 @@ def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> st
154156
]
155157
else:
156158
error_probabilities = [f"{pauli}:{p}" for pauli, p in self._error_probabilities.items()]
157-
return f"A({', '.join(error_probabilities)})"
159+
return [f"A({', '.join(error_probabilities)})"] + [
160+
f'({i})' for i in range(1, self._num_qubits)
161+
]
158162

159163
@property
160164
def p_i(self) -> float:
@@ -193,13 +197,9 @@ def _json_dict_(self) -> Dict[str, Any]:
193197
return protocols.obj_to_dict_helper(self, ['error_probabilities'])
194198

195199
def _approx_eq_(self, other: Any, atol: float) -> bool:
196-
return (
197-
self._num_qubits == other._num_qubits
198-
and np.isclose(self.p_i, other.p_i, atol=atol).item()
199-
and np.isclose(self.p_x, other.p_x, atol=atol).item()
200-
and np.isclose(self.p_y, other.p_y, atol=atol).item()
201-
and np.isclose(self.p_z, other.p_z, atol=atol).item()
202-
)
200+
self_keys, self_values = zip(*sorted(self.error_probabilities.items()))
201+
other_keys, other_values = zip(*sorted(other.error_probabilities.items()))
202+
return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol)
203203

204204

205205
def asymmetric_depolarize(

cirq-core/cirq/ops/common_channels_test.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,38 @@ def test_multi_asymmetric_depolarizing_channel_repr():
805805
)
806806

807807

808+
def test_multi_asymmetric_depolarizing_eq():
809+
a = cirq.asymmetric_depolarize(error_probabilities={'I': 0.8, 'X': 0.2})
810+
b = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
811+
812+
assert not cirq.approx_eq(a, b)
813+
814+
a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
815+
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
816+
817+
assert not cirq.approx_eq(a, b)
818+
819+
a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'ZZ': 1 / 3})
820+
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
821+
822+
assert not cirq.approx_eq(a, b)
823+
824+
a = cirq.asymmetric_depolarize(0.1, 0.2)
825+
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
826+
827+
assert not cirq.approx_eq(a, b)
828+
829+
a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333})
830+
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
831+
832+
assert cirq.approx_eq(a, b, atol=1e-3)
833+
834+
a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333})
835+
b = cirq.asymmetric_depolarize(error_probabilities={'XX': 1 / 3, 'II': 2 / 3})
836+
837+
assert cirq.approx_eq(a, b, atol=1e-3)
838+
839+
808840
def test_multi_asymmetric_depolarizing_channel_str():
809841
assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == (
810842
"asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})"
@@ -814,16 +846,16 @@ def test_multi_asymmetric_depolarizing_channel_str():
814846
def test_multi_asymmetric_depolarizing_channel_text_diagram():
815847
a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
816848
assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo(
817-
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',)
849+
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)')
818850
)
819851
assert cirq.circuit_diagram_info(a, args=round_to_6_prec) == cirq.CircuitDiagramInfo(
820-
wire_symbols=('A(II:0.666667, XX:0.333333)',)
852+
wire_symbols=('A(II:0.666667, XX:0.333333)', '(1)')
821853
)
822854
assert cirq.circuit_diagram_info(a, args=round_to_2_prec) == cirq.CircuitDiagramInfo(
823-
wire_symbols=('A(II:0.67, XX:0.33)',)
855+
wire_symbols=('A(II:0.67, XX:0.33)', '(1)')
824856
)
825857
assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo(
826-
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',)
858+
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)')
827859
)
828860

829861

0 commit comments

Comments
 (0)