Skip to content

Commit f7cae83

Browse files
committed
Address review comments
1 parent c80fe6f commit f7cae83

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

tensorflow_quantum/core/serialize/serializer.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -990,10 +990,10 @@ def serialize_projectorsum(projectorsum):
990990
raise TypeError("serialize requires a cirq.ProjectorSum object."
991991
" Given: " + str(type(projectorsum)))
992992

993-
if any(not isinstance(qubit, cirq.GridQubit)
993+
if any(not isinstance(qubit, (cirq.GridQubit, cirq.LineQubit))
994994
for qubit in projectorsum.qubits):
995995
raise ValueError("Attempted to serialize a paulisum that doesn't use "
996-
"only cirq.GridQubits.")
996+
"only cirq.GridQubit or cirq.LineQubit.")
997997

998998
projectorsum_proto = projector_sum_pb2.ProjectorSum()
999999
for term in projectorsum:
@@ -1003,13 +1003,9 @@ def serialize_projectorsum(projectorsum):
10031003
projectorterm_proto.coefficient_imag = term.coefficient.imag
10041004
for qubit, basis_state in sorted(
10051005
term.projector_dict.items()): # sort to keep qubits ordered
1006-
if basis_state:
1007-
projectorterm_proto.projector_dict.add(
1008-
qubit_id=op_serializer.qubit_to_proto(qubit),
1009-
basis_state=basis_state)
1010-
else:
1011-
projectorterm_proto.projector_dict.add(
1012-
qubit_id=op_serializer.qubit_to_proto(qubit))
1006+
projectorterm_proto.projector_dict.add(
1007+
qubit_id=op_serializer.qubit_to_proto(qubit),
1008+
basis_state=basis_state)
10131009

10141010
projectorsum_proto.terms.extend([projectorterm_proto])
10151011

tensorflow_quantum/core/serialize/serializer_test.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -439,24 +439,32 @@ def _get_valid_pauli_proto_pairs(qubit_type='grid'):
439439
return pairs
440440

441441

442-
def _get_valid_projector_proto_pairs():
442+
def _get_valid_projector_proto_pairs(qubit_type='grid'):
443443
"""Generate valid projectorsum proto pairs."""
444444
q0 = cirq.GridQubit(0, 0)
445445
q1 = cirq.GridQubit(1, 0)
446+
q0_str = '0_0'
447+
q1_str = '1_0'
448+
449+
if qubit_type == 'line':
450+
q0 = cirq.LineQubit(0)
451+
q1 = cirq.LineQubit(1)
452+
q0_str = '0'
453+
q1_str = '1'
446454

447455
pairs = [
448456
(cirq.ProjectorSum.from_projector_strings(
449457
cirq.ProjectorString(projector_dict={q0: 0})),
450-
_build_projector_proto([1.0], [[0]], [['0_0']])),
458+
_build_projector_proto([1.0], [[0]], [[q0_str]])),
451459
(cirq.ProjectorSum.from_projector_strings(
452460
cirq.ProjectorString(projector_dict={q0: 0}, coefficient=0.125j)),
453-
_build_projector_proto([0.125j], [[0]], [['0_0']])),
461+
_build_projector_proto([0.125j], [[0]], [[q0_str]])),
454462
(cirq.ProjectorSum.from_projector_strings([
455463
cirq.ProjectorString(projector_dict={
456464
q0: 0,
457465
q1: 1
458466
}),
459-
]), _build_projector_proto([1.0], [[0, 1]], [['0_0', '1_0']])),
467+
]), _build_projector_proto([1.0], [[0, 1]], [[q0_str, q1_str]])),
460468
]
461469

462470
return pairs
@@ -779,26 +787,32 @@ def test_serialize_projectorsum_invalid(self):
779787
with self.assertRaises(ValueError):
780788
serializer.serialize_projectorsum(a)
781789

782-
@parameterized.parameters([{
783-
'sum_proto_pair': v
784-
} for v in _get_valid_projector_proto_pairs()])
790+
@parameterized.parameters(
791+
[{
792+
'sum_proto_pair': v
793+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
794+
_get_valid_projector_proto_pairs(qubit_type='line')])
785795
def test_serialize_projectorsum_simple(self, sum_proto_pair):
786796
"""Ensure serialization is correct."""
787797
self.assertProtoEquals(
788798
sum_proto_pair[1],
789799
serializer.serialize_projectorsum(sum_proto_pair[0]))
790800

791-
@parameterized.parameters([{
792-
'sum_proto_pair': v
793-
} for v in _get_valid_projector_proto_pairs()])
801+
@parameterized.parameters(
802+
[{
803+
'sum_proto_pair': v
804+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
805+
_get_valid_projector_proto_pairs(qubit_type='line')])
794806
def test_deserialize_projectorsum_simple(self, sum_proto_pair):
795807
"""Ensure deserialization is correct."""
796808
self.assertEqual(serializer.deserialize_projectorsum(sum_proto_pair[1]),
797809
sum_proto_pair[0])
798810

799-
@parameterized.parameters([{
800-
'sum_proto_pair': v
801-
} for v in _get_valid_projector_proto_pairs()])
811+
@parameterized.parameters(
812+
[{
813+
'sum_proto_pair': v
814+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
815+
_get_valid_projector_proto_pairs(qubit_type='line')])
802816
def test_serialize_deserialize_projectorsum_consistency(
803817
self, sum_proto_pair):
804818
"""Serialize and deserialize and ensure nothing changed."""

0 commit comments

Comments
 (0)