Skip to content

Commit aa7afff

Browse files
committed
Auto merge of rust-lang#3216 - eduardosm:fix-ptestnzc, r=RalfJung
Fix x86 SSE4.1 ptestnzc Fixes ptestnzc by bringing back the original implementation of rust-lang/miri#3214. `(op & mask) != 0 && (op & mask) == !ask` need to be calculated for the whole vector. It cannot be calculated for each element and then folded. For example, given * `op = [0b100, 0b010]` * `mask = [0b100, 0b110]` The correct result would be: * `op & mask = [0b100, 0b010]` Comparisons are done on the vector as a whole: * `all_zero = (op & mask) == [0, 0] = false` * `masked_set = (op & mask) == mask = false` * `!all_zero && !masked_set = true` correct result The previous method: * `op & mask = [0b100, 0b010]` Comparisons are done element-wise: * `all_zero = (op & mask) == [0, 0] = [true, true]` * `masked_set = (op & mask) == mask = [true, false]` * `!all_zero && !masked_set = [true, false]` After folding with AND, the final result would be `false`, which is incorrect.
2 parents 57935c3 + d571256 commit aa7afff

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

src/tools/miri/src/shims/x86/mod.rs

+26-23
Original file line numberDiff line numberDiff line change
@@ -666,30 +666,33 @@ fn conditional_dot_product<'tcx>(
666666
Ok(())
667667
}
668668

669-
/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
670-
fn bin_op_folded<'tcx, T>(
669+
/// Calculates two booleans.
670+
///
671+
/// The first is true when all the bits of `op & mask` are zero.
672+
/// The second is true when `(op & mask) == mask`
673+
fn test_bits_masked<'tcx>(
671674
this: &crate::MiriInterpCx<'_, 'tcx>,
672-
lhs: &OpTy<'tcx, Provenance>,
673-
rhs: &OpTy<'tcx, Provenance>,
674-
init: T,
675-
mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
676-
) -> InterpResult<'tcx, T> {
677-
assert_eq!(lhs.layout, rhs.layout);
678-
679-
let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
680-
let (rhs, rhs_len) = this.operand_to_simd(rhs)?;
681-
682-
assert_eq!(lhs_len, rhs_len);
683-
684-
let mut acc = init;
685-
for i in 0..lhs_len {
686-
let lhs = this.project_index(&lhs, i)?;
687-
let rhs = this.project_index(&rhs, i)?;
688-
689-
let lhs = this.read_immediate(&lhs)?;
690-
let rhs = this.read_immediate(&rhs)?;
691-
acc = f(acc, lhs, rhs)?;
675+
op: &OpTy<'tcx, Provenance>,
676+
mask: &OpTy<'tcx, Provenance>,
677+
) -> InterpResult<'tcx, (bool, bool)> {
678+
assert_eq!(op.layout, mask.layout);
679+
680+
let (op, op_len) = this.operand_to_simd(op)?;
681+
let (mask, mask_len) = this.operand_to_simd(mask)?;
682+
683+
assert_eq!(op_len, mask_len);
684+
685+
let mut all_zero = true;
686+
let mut masked_set = true;
687+
for i in 0..op_len {
688+
let op = this.project_index(&op, i)?;
689+
let mask = this.project_index(&mask, i)?;
690+
691+
let op = this.read_scalar(&op)?.to_uint(op.layout.size)?;
692+
let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?;
693+
all_zero &= (op & mask) == 0;
694+
masked_set &= (op & mask) == mask;
692695
}
693696

694-
Ok(acc)
697+
Ok((all_zero, masked_set))
695698
}

src/tools/miri/src/shims/x86/sse41.rs

+10-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use rustc_span::Symbol;
22
use rustc_target::spec::abi::Abi;
33

4-
use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
4+
use super::{conditional_dot_product, round_all, round_first, test_bits_masked};
55
use crate::*;
66
use shims::foreign_items::EmulateForeignItemResult;
77

@@ -217,21 +217,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
217217
}
218218
// Used to implement the _mm_testz_si128, _mm_testc_si128
219219
// and _mm_testnzc_si128 functions.
220-
// Tests `op & mask == 0`, `op & mask == mask` or
221-
// `op & mask != 0 && op & mask != mask`
220+
// Tests `(op & mask) == 0`, `(op & mask) == mask` or
221+
// `(op & mask) != 0 && (op & mask) != mask`
222222
"ptestz" | "ptestc" | "ptestnzc" => {
223223
let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
224224

225-
let res = bin_op_folded(this, op, mask, true, |acc, op, mask| {
226-
let op = op.to_scalar().to_uint(op.layout.size)?;
227-
let mask = mask.to_scalar().to_uint(mask.layout.size)?;
228-
Ok(match unprefixed_name {
229-
"ptestz" => acc && (op & mask) == 0,
230-
"ptestc" => acc && (op & mask) == mask,
231-
"ptestnzc" => acc && (op & mask) != 0 && (op & mask) != mask,
232-
_ => unreachable!(),
233-
})
234-
})?;
225+
let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
226+
let res = match unprefixed_name {
227+
"ptestz" => all_zero,
228+
"ptestc" => masked_set,
229+
"ptestnzc" => !all_zero && !masked_set,
230+
_ => unreachable!(),
231+
};
235232

236233
this.write_scalar(Scalar::from_i32(res.into()), dest)?;
237234
}

src/tools/miri/tests/pass/intrinsics-x86-sse41.rs

+5
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,11 @@ unsafe fn test_sse41() {
515515
let mask = _mm_set1_epi8(0b101);
516516
let r = _mm_testnzc_si128(a, mask);
517517
assert_eq!(r, 0);
518+
519+
let a = _mm_setr_epi32(0b100, 0, 0, 0b010);
520+
let mask = _mm_setr_epi32(0b100, 0, 0, 0b110);
521+
let r = _mm_testnzc_si128(a, mask);
522+
assert_eq!(r, 1);
518523
}
519524
test_mm_testnzc_si128();
520525
}

0 commit comments

Comments
 (0)