Skip to content

Commit 6796c57

Browse files
committed
miri: make NaN generation non-deterministic
1 parent d087c6f commit 6796c57

File tree

6 files changed

+385
-25
lines changed

6 files changed

+385
-25
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+6
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
500500
b: &ImmTy<'tcx, M::Provenance>,
501501
dest: &PlaceTy<'tcx, M::Provenance>,
502502
) -> InterpResult<'tcx> {
503+
assert_eq!(a.layout.ty, b.layout.ty);
504+
assert!(matches!(a.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
505+
503506
// Performs an exact division, resulting in undefined behavior where
504507
// `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`.
505508
// First, check x % y != 0 (or if that computation overflows).
@@ -522,7 +525,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
522525
l: &ImmTy<'tcx, M::Provenance>,
523526
r: &ImmTy<'tcx, M::Provenance>,
524527
) -> InterpResult<'tcx, Scalar<M::Provenance>> {
528+
assert_eq!(l.layout.ty, r.layout.ty);
529+
assert!(matches!(l.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
525530
assert!(matches!(mir_op, BinOp::Add | BinOp::Sub));
531+
526532
let (val, overflowed) = self.overflowing_binary_op(mir_op, l, r)?;
527533
Ok(if overflowed {
528534
let size = l.layout.size;

compiler/rustc_const_eval/src/interpret/machine.rs

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::borrow::{Borrow, Cow};
66
use std::fmt::Debug;
77
use std::hash::Hash;
88

9+
use rustc_apfloat::Float;
910
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
1011
use rustc_middle::mir;
1112
use rustc_middle::ty::layout::TyAndLayout;
@@ -240,6 +241,13 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
240241
right: &ImmTy<'tcx, Self::Provenance>,
241242
) -> InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)>;
242243

244+
/// Generate the NaN returned by a float operation, given the list of inputs.
245+
/// (This is all inputs, not just NaN inputs!)
246+
fn generate_nan<F: Float>(_ecx: &InterpCx<'mir, 'tcx, Self>, _inputs: &[F]) -> F {
247+
// By default we always return the preferred NaN.
248+
F::NAN
249+
}
250+
243251
/// Called before writing the specified `local` of the `frame`.
244252
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
245253
/// for ZST reads.

compiler/rustc_const_eval/src/interpret/operator.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,23 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
113113
) -> (ImmTy<'tcx, M::Provenance>, bool) {
114114
use rustc_middle::mir::BinOp::*;
115115

116+
// Performs appropriate non-deterministic adjustments of NaN results.
117+
let adjust_nan = |f: F| -> F {
118+
if f.is_nan() { M::generate_nan(self, &[l, r]) } else { f }
119+
};
120+
116121
let val = match bin_op {
117122
Eq => ImmTy::from_bool(l == r, *self.tcx),
118123
Ne => ImmTy::from_bool(l != r, *self.tcx),
119124
Lt => ImmTy::from_bool(l < r, *self.tcx),
120125
Le => ImmTy::from_bool(l <= r, *self.tcx),
121126
Gt => ImmTy::from_bool(l > r, *self.tcx),
122127
Ge => ImmTy::from_bool(l >= r, *self.tcx),
123-
Add => ImmTy::from_scalar((l + r).value.into(), layout),
124-
Sub => ImmTy::from_scalar((l - r).value.into(), layout),
125-
Mul => ImmTy::from_scalar((l * r).value.into(), layout),
126-
Div => ImmTy::from_scalar((l / r).value.into(), layout),
127-
Rem => ImmTy::from_scalar((l % r).value.into(), layout),
128+
Add => ImmTy::from_scalar(adjust_nan((l + r).value).into(), layout),
129+
Sub => ImmTy::from_scalar(adjust_nan((l - r).value).into(), layout),
130+
Mul => ImmTy::from_scalar(adjust_nan((l * r).value).into(), layout),
131+
Div => ImmTy::from_scalar(adjust_nan((l / r).value).into(), layout),
132+
Rem => ImmTy::from_scalar(adjust_nan((l % r).value).into(), layout),
128133
_ => span_bug!(self.cur_span(), "invalid float op: `{:?}`", bin_op),
129134
};
130135
(val, false)

src/tools/miri/src/machine.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,11 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
10011001
ecx.binary_ptr_op(bin_op, left, right)
10021002
}
10031003

1004+
#[inline(always)]
1005+
fn generate_nan<F: rustc_apfloat::Float>(ecx: &InterpCx<'mir, 'tcx, Self>, inputs: &[F]) -> F {
1006+
ecx.generate_nan(inputs)
1007+
}
1008+
10041009
fn thread_local_static_base_pointer(
10051010
ecx: &mut MiriInterpCx<'mir, 'tcx>,
10061011
def_id: DefId,

src/tools/miri/src/operator.rs

+40-20
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1+
use std::iter;
2+
13
use log::trace;
24

5+
use rand::{seq::IteratorRandom, Rng};
6+
use rustc_apfloat::Float;
37
use rustc_middle::mir;
48
use rustc_target::abi::Size;
59

610
use crate::*;
711

8-
pub trait EvalContextExt<'tcx> {
9-
fn binary_ptr_op(
10-
&self,
11-
bin_op: mir::BinOp,
12-
left: &ImmTy<'tcx, Provenance>,
13-
right: &ImmTy<'tcx, Provenance>,
14-
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)>;
15-
}
16-
17-
impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
12+
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
13+
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1814
fn binary_ptr_op(
1915
&self,
2016
bin_op: mir::BinOp,
@@ -23,12 +19,13 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
2319
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)> {
2420
use rustc_middle::mir::BinOp::*;
2521

22+
let this = self.eval_context_ref();
2623
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);
2724

2825
Ok(match bin_op {
2926
Eq | Ne | Lt | Le | Gt | Ge => {
3027
assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for`
31-
let size = self.pointer_size();
28+
let size = this.pointer_size();
3229
// Just compare the bits. ScalarPairs are compared lexicographically.
3330
// We thus always compare pairs and simply fill scalars up with 0.
3431
let left = match **left {
@@ -50,34 +47,57 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
5047
Ge => left >= right,
5148
_ => bug!(),
5249
};
53-
(ImmTy::from_bool(res, *self.tcx), false)
50+
(ImmTy::from_bool(res, *this.tcx), false)
5451
}
5552

5653
// Some more operations are possible with atomics.
5754
// The return value always has the provenance of the *left* operand.
5855
Add | Sub | BitOr | BitAnd | BitXor => {
5956
assert!(left.layout.ty.is_unsafe_ptr());
6057
assert!(right.layout.ty.is_unsafe_ptr());
61-
let ptr = left.to_scalar().to_pointer(self)?;
58+
let ptr = left.to_scalar().to_pointer(this)?;
6259
// We do the actual operation with usize-typed scalars.
63-
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
60+
let left = ImmTy::from_uint(ptr.addr().bytes(), this.machine.layouts.usize);
6461
let right = ImmTy::from_uint(
65-
right.to_scalar().to_target_usize(self)?,
66-
self.machine.layouts.usize,
62+
right.to_scalar().to_target_usize(this)?,
63+
this.machine.layouts.usize,
6764
);
68-
let (result, overflowing) = self.overflowing_binary_op(bin_op, &left, &right)?;
65+
let (result, overflowing) = this.overflowing_binary_op(bin_op, &left, &right)?;
6966
// Construct a new pointer with the provenance of `ptr` (the LHS).
7067
let result_ptr = Pointer::new(
7168
ptr.provenance,
72-
Size::from_bytes(result.to_scalar().to_target_usize(self)?),
69+
Size::from_bytes(result.to_scalar().to_target_usize(this)?),
7370
);
7471
(
75-
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, self), left.layout),
72+
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, this), left.layout),
7673
overflowing,
7774
)
7875
}
7976

80-
_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
77+
_ => span_bug!(this.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
8178
})
8279
}
80+
81+
fn generate_nan<F: Float>(&self, inputs: &[F]) -> F {
82+
let this = self.eval_context_ref();
83+
let mut rand = this.machine.rng.borrow_mut();
84+
// Assemble an iterator of possible NaNs: preferred, unchanged propagation, quieting propagation.
85+
let preferred_nan = F::qnan(Some(0));
86+
let nans = iter::once(preferred_nan)
87+
.chain(inputs.iter().filter(|f| f.is_nan()).copied())
88+
.chain(inputs.iter().filter(|f| f.is_signaling()).map(|f| {
89+
// Make it quiet, by setting the bit. We assume that `preferred_nan`
90+
// only has bits set that all quiet NaNs need to have set.
91+
F::from_bits(f.to_bits() | preferred_nan.to_bits())
92+
}));
93+
// Pick one of the NaNs.
94+
let nan = nans.choose(&mut *rand).unwrap();
95+
// Non-deterministically flip the sign.
96+
if rand.gen() {
97+
// This will properly flip even for NaN.
98+
-nan
99+
} else {
100+
nan
101+
}
102+
}
83103
}

0 commit comments

Comments
 (0)