Skip to content

Commit 1fccd44

Browse files
authored
Fix state ordering in Python (#1122)
This change uses the vector returned from the internals directly rather than converting into a hashmap so that the state ordering can be preserved for display. Fixes #1119 Before: ![image](https://github.com/microsoft/qsharp/assets/10567287/4ff3d4d1-021b-4b27-b797-266312cc13cc) After: ![image](https://github.com/microsoft/qsharp/assets/10567287/eff67e4e-a756-45e3-b246-16829f9befcf)
1 parent 09be372 commit 1fccd44

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

pip/src/displayable_output.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ mod tests;
77
use num_bigint::BigUint;
88
use num_complex::{Complex64, ComplexFloat};
99
use qsc::{fmt_basis_state_label, fmt_complex, format_state_id, get_phase};
10-
use rustc_hash::FxHashMap;
1110
use std::fmt::Write;
1211

1312
#[derive(Clone)]
14-
pub struct DisplayableState(pub FxHashMap<BigUint, Complex64>, pub usize);
13+
pub struct DisplayableState(pub Vec<(BigUint, Complex64)>, pub usize);
1514

1615
impl DisplayableState {
1716
pub fn to_plain(&self) -> String {

pip/src/displayable_output/tests.rs

+19-10
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,39 @@
33

44
use num_bigint::BigUint;
55
use num_complex::Complex;
6-
use rustc_hash::FxHashMap;
76

87
use crate::displayable_output::DisplayableState;
98

109
#[test]
1110
fn display_neg_zero() {
12-
let s = DisplayableState(
13-
vec![(BigUint::default(), Complex::new(-0.0, -0.0))]
14-
.into_iter()
15-
.collect::<FxHashMap<_, _>>(),
16-
1,
17-
);
11+
let s = DisplayableState(vec![(BigUint::default(), Complex::new(-0.0, -0.0))], 1);
1812
// -0 should be displayed as 0.0000 without a minus sign
1913
assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain());
2014
}
2115

2216
#[test]
2317
fn display_rounds_to_neg_zero() {
2418
let s = DisplayableState(
25-
vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))]
26-
.into_iter()
27-
.collect::<FxHashMap<_, _>>(),
19+
vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))],
2820
1,
2921
);
3022
// -0.00001 should be displayed as 0.0000 without a minus sign
3123
assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain());
3224
}
25+
26+
#[test]
27+
fn display_preserves_order() {
28+
let s = DisplayableState(
29+
vec![
30+
(BigUint::from(0_u64), Complex::new(0.0, 0.0)),
31+
(BigUint::from(1_u64), Complex::new(0.0, 1.0)),
32+
(BigUint::from(2_u64), Complex::new(1.0, 0.0)),
33+
(BigUint::from(3_u64), Complex::new(1.0, 1.0)),
34+
],
35+
2,
36+
);
37+
assert_eq!(
38+
"STATE:\n|00⟩: 0.0000+0.0000𝑖\n|01⟩: 0.0000+1.0000𝑖\n|10⟩: 1.0000+0.0000𝑖\n|11⟩: 1.0000+1.0000𝑖",
39+
s.to_plain()
40+
);
41+
}

pip/src/interpreter.rs

+8-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use qsc::{
2828
PackageType, SourceMap,
2929
};
3030
use resource_estimator::{self as re, estimate_expr};
31-
use rustc_hash::FxHashMap;
3231
use std::fmt::Write;
3332

3433
#[pymodule]
@@ -168,10 +167,7 @@ impl Interpreter {
168167
/// pairs of real and imaginary amplitudes.
169168
fn dump_machine(&mut self) -> StateDump {
170169
let (state, qubit_count) = self.interpreter.get_quantum_state();
171-
StateDump(DisplayableState(
172-
state.into_iter().collect::<FxHashMap<_, _>>(),
173-
qubit_count,
174-
))
170+
StateDump(DisplayableState(state, qubit_count))
175171
}
176172

177173
fn run(
@@ -336,7 +332,13 @@ impl StateDump {
336332
// Pass by value is needed for compatiblity with the pyo3 API.
337333
#[allow(clippy::needless_pass_by_value)]
338334
fn __getitem__(&self, key: BigUint) -> Option<(f64, f64)> {
339-
self.0 .0.get(&key).map(|state| (state.re, state.im))
335+
self.0 .0.iter().find_map(|state| {
336+
if state.0 == key {
337+
Some((state.1.re, state.1.im))
338+
} else {
339+
None
340+
}
341+
})
340342
}
341343

342344
fn __len__(&self) -> usize {
@@ -459,7 +461,6 @@ impl Receiver for OptionalCallbackReceiver<'_> {
459461
qubit_count: usize,
460462
) -> core::result::Result<(), Error> {
461463
if let Some(callback) = &self.callback {
462-
let state = state.into_iter().collect::<FxHashMap<_, _>>();
463464
let out = DisplayableOutput::State(DisplayableState(state, qubit_count));
464465
callback
465466
.call1(

0 commit comments

Comments
 (0)