Skip to content

Commit 061c330

Browse files
committed
Auto merge of rust-lang#116551 - RalfJung:nondet-nan, r=oli-obk
miri: make NaN generation non-deterministic This implements the [LLVM semantics for NaN generation](https://llvm.org/docs/LangRef.html#behavior-of-floating-point-nan-values). I will soon submit an RFC to make this also officially the Rust semantics, but it has been our de-facto semantics for a long time so there's no reason Miri has to wait for that RFC. This PR just better aligns Miri with codegen. This PR does that just for the operations that have MIR primitives; a future PR will adjust the intrinsics.
2 parents 5c37696 + 08deb0d commit 061c330

File tree

7 files changed

+531
-29
lines changed

7 files changed

+531
-29
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
311311
F: Float + Into<Scalar<M::Provenance>> + FloatConvert<Single> + FloatConvert<Double>,
312312
{
313313
use rustc_type_ir::sty::TyKind::*;
314+
315+
fn adjust_nan<
316+
'mir,
317+
'tcx: 'mir,
318+
M: Machine<'mir, 'tcx>,
319+
F1: rustc_apfloat::Float + FloatConvert<F2>,
320+
F2: rustc_apfloat::Float,
321+
>(
322+
ecx: &InterpCx<'mir, 'tcx, M>,
323+
f1: F1,
324+
f2: F2,
325+
) -> F2 {
326+
if f2.is_nan() { M::generate_nan(ecx, &[f1]) } else { f2 }
327+
}
328+
314329
match *dest_ty.kind() {
315330
// float -> uint
316331
Uint(t) => {
@@ -330,9 +345,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
330345
Scalar::from_int(v, size)
331346
}
332347
// float -> f32
333-
Float(FloatTy::F32) => Scalar::from_f32(f.convert(&mut false).value),
348+
Float(FloatTy::F32) => {
349+
Scalar::from_f32(adjust_nan(self, f, f.convert(&mut false).value))
350+
}
334351
// float -> f64
335-
Float(FloatTy::F64) => Scalar::from_f64(f.convert(&mut false).value),
352+
Float(FloatTy::F64) => {
353+
Scalar::from_f64(adjust_nan(self, f, f.convert(&mut false).value))
354+
}
336355
// That's it.
337356
_ => span_bug!(self.cur_span(), "invalid float to {} cast", dest_ty),
338357
}

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+6
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
500500
b: &ImmTy<'tcx, M::Provenance>,
501501
dest: &PlaceTy<'tcx, M::Provenance>,
502502
) -> InterpResult<'tcx> {
503+
assert_eq!(a.layout.ty, b.layout.ty);
504+
assert!(matches!(a.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
505+
503506
// Performs an exact division, resulting in undefined behavior where
504507
// `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`.
505508
// First, check x % y != 0 (or if that computation overflows).
@@ -522,7 +525,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
522525
l: &ImmTy<'tcx, M::Provenance>,
523526
r: &ImmTy<'tcx, M::Provenance>,
524527
) -> InterpResult<'tcx, Scalar<M::Provenance>> {
528+
assert_eq!(l.layout.ty, r.layout.ty);
529+
assert!(matches!(l.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
525530
assert!(matches!(mir_op, BinOp::Add | BinOp::Sub));
531+
526532
let (val, overflowed) = self.overflowing_binary_op(mir_op, l, r)?;
527533
Ok(if overflowed {
528534
let size = l.layout.size;

compiler/rustc_const_eval/src/interpret/machine.rs

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::borrow::{Borrow, Cow};
66
use std::fmt::Debug;
77
use std::hash::Hash;
88

9+
use rustc_apfloat::{Float, FloatConvert};
910
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
1011
use rustc_middle::mir;
1112
use rustc_middle::ty::layout::TyAndLayout;
@@ -240,6 +241,16 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
240241
right: &ImmTy<'tcx, Self::Provenance>,
241242
) -> InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)>;
242243

244+
/// Generate the NaN returned by a float operation, given the list of inputs.
245+
/// (This is all inputs, not just NaN inputs!)
246+
fn generate_nan<F1: Float + FloatConvert<F2>, F2: Float>(
247+
_ecx: &InterpCx<'mir, 'tcx, Self>,
248+
_inputs: &[F1],
249+
) -> F2 {
250+
// By default we always return the preferred NaN.
251+
F2::NAN
252+
}
253+
243254
/// Called before writing the specified `local` of the `frame`.
244255
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
245256
/// for ZST reads.

compiler/rustc_const_eval/src/interpret/operator.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use rustc_apfloat::Float;
1+
use rustc_apfloat::{Float, FloatConvert};
22
use rustc_middle::mir;
33
use rustc_middle::mir::interpret::{InterpResult, Scalar};
44
use rustc_middle::ty::layout::TyAndLayout;
@@ -104,7 +104,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
104104
(ImmTy::from_bool(res, *self.tcx), false)
105105
}
106106

107-
fn binary_float_op<F: Float + Into<Scalar<M::Provenance>>>(
107+
fn binary_float_op<F: Float + FloatConvert<F> + Into<Scalar<M::Provenance>>>(
108108
&self,
109109
bin_op: mir::BinOp,
110110
layout: TyAndLayout<'tcx>,
@@ -113,18 +113,23 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
113113
) -> (ImmTy<'tcx, M::Provenance>, bool) {
114114
use rustc_middle::mir::BinOp::*;
115115

116+
// Performs appropriate non-deterministic adjustments of NaN results.
117+
let adjust_nan = |f: F| -> F {
118+
if f.is_nan() { M::generate_nan(self, &[l, r]) } else { f }
119+
};
120+
116121
let val = match bin_op {
117122
Eq => ImmTy::from_bool(l == r, *self.tcx),
118123
Ne => ImmTy::from_bool(l != r, *self.tcx),
119124
Lt => ImmTy::from_bool(l < r, *self.tcx),
120125
Le => ImmTy::from_bool(l <= r, *self.tcx),
121126
Gt => ImmTy::from_bool(l > r, *self.tcx),
122127
Ge => ImmTy::from_bool(l >= r, *self.tcx),
123-
Add => ImmTy::from_scalar((l + r).value.into(), layout),
124-
Sub => ImmTy::from_scalar((l - r).value.into(), layout),
125-
Mul => ImmTy::from_scalar((l * r).value.into(), layout),
126-
Div => ImmTy::from_scalar((l / r).value.into(), layout),
127-
Rem => ImmTy::from_scalar((l % r).value.into(), layout),
128+
Add => ImmTy::from_scalar(adjust_nan((l + r).value).into(), layout),
129+
Sub => ImmTy::from_scalar(adjust_nan((l - r).value).into(), layout),
130+
Mul => ImmTy::from_scalar(adjust_nan((l * r).value).into(), layout),
131+
Div => ImmTy::from_scalar(adjust_nan((l / r).value).into(), layout),
132+
Rem => ImmTy::from_scalar(adjust_nan((l % r).value).into(), layout),
128133
_ => span_bug!(self.cur_span(), "invalid float op: `{:?}`", bin_op),
129134
};
130135
(val, false)
@@ -456,6 +461,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
456461
Ok((ImmTy::from_bool(res, *self.tcx), false))
457462
}
458463
ty::Float(fty) => {
464+
// No NaN adjustment here, `-` is a bitwise operation!
459465
let res = match (un_op, fty) {
460466
(Neg, FloatTy::F32) => Scalar::from_f32(-val.to_f32()?),
461467
(Neg, FloatTy::F64) => Scalar::from_f64(-val.to_f64()?),

src/tools/miri/src/machine.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,14 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
10011001
ecx.binary_ptr_op(bin_op, left, right)
10021002
}
10031003

1004+
#[inline(always)]
1005+
fn generate_nan<F1: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F2>, F2: rustc_apfloat::Float>(
1006+
ecx: &InterpCx<'mir, 'tcx, Self>,
1007+
inputs: &[F1],
1008+
) -> F2 {
1009+
ecx.generate_nan(inputs)
1010+
}
1011+
10041012
fn thread_local_static_base_pointer(
10051013
ecx: &mut MiriInterpCx<'mir, 'tcx>,
10061014
def_id: DefId,

src/tools/miri/src/operator.rs

+58-20
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1+
use std::iter;
2+
13
use log::trace;
24

5+
use rand::{seq::IteratorRandom, Rng};
6+
use rustc_apfloat::{Float, FloatConvert};
37
use rustc_middle::mir;
48
use rustc_target::abi::Size;
59

610
use crate::*;
711

8-
pub trait EvalContextExt<'tcx> {
9-
fn binary_ptr_op(
10-
&self,
11-
bin_op: mir::BinOp,
12-
left: &ImmTy<'tcx, Provenance>,
13-
right: &ImmTy<'tcx, Provenance>,
14-
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)>;
15-
}
16-
17-
impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
12+
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
13+
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1814
fn binary_ptr_op(
1915
&self,
2016
bin_op: mir::BinOp,
@@ -23,12 +19,13 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
2319
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)> {
2420
use rustc_middle::mir::BinOp::*;
2521

22+
let this = self.eval_context_ref();
2623
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);
2724

2825
Ok(match bin_op {
2926
Eq | Ne | Lt | Le | Gt | Ge => {
3027
assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for`
31-
let size = self.pointer_size();
28+
let size = this.pointer_size();
3229
// Just compare the bits. ScalarPairs are compared lexicographically.
3330
// We thus always compare pairs and simply fill scalars up with 0.
3431
let left = match **left {
@@ -50,34 +47,75 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
5047
Ge => left >= right,
5148
_ => bug!(),
5249
};
53-
(ImmTy::from_bool(res, *self.tcx), false)
50+
(ImmTy::from_bool(res, *this.tcx), false)
5451
}
5552

5653
// Some more operations are possible with atomics.
5754
// The return value always has the provenance of the *left* operand.
5855
Add | Sub | BitOr | BitAnd | BitXor => {
5956
assert!(left.layout.ty.is_unsafe_ptr());
6057
assert!(right.layout.ty.is_unsafe_ptr());
61-
let ptr = left.to_scalar().to_pointer(self)?;
58+
let ptr = left.to_scalar().to_pointer(this)?;
6259
// We do the actual operation with usize-typed scalars.
63-
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
60+
let left = ImmTy::from_uint(ptr.addr().bytes(), this.machine.layouts.usize);
6461
let right = ImmTy::from_uint(
65-
right.to_scalar().to_target_usize(self)?,
66-
self.machine.layouts.usize,
62+
right.to_scalar().to_target_usize(this)?,
63+
this.machine.layouts.usize,
6764
);
68-
let (result, overflowing) = self.overflowing_binary_op(bin_op, &left, &right)?;
65+
let (result, overflowing) = this.overflowing_binary_op(bin_op, &left, &right)?;
6966
// Construct a new pointer with the provenance of `ptr` (the LHS).
7067
let result_ptr = Pointer::new(
7168
ptr.provenance,
72-
Size::from_bytes(result.to_scalar().to_target_usize(self)?),
69+
Size::from_bytes(result.to_scalar().to_target_usize(this)?),
7370
);
7471
(
75-
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, self), left.layout),
72+
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, this), left.layout),
7673
overflowing,
7774
)
7875
}
7976

80-
_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
77+
_ => span_bug!(this.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
8178
})
8279
}
80+
81+
fn generate_nan<F1: Float + FloatConvert<F2>, F2: Float>(&self, inputs: &[F1]) -> F2 {
82+
/// Make the given NaN a signaling NaN.
83+
/// Returns `None` if this would not result in a NaN.
84+
fn make_signaling<F: Float>(f: F) -> Option<F> {
85+
// The quiet/signaling bit is the leftmost bit in the mantissa.
86+
// That's position `PRECISION-1`, since `PRECISION` includes the fixed leading 1 bit,
87+
// and then we subtract 1 more since this is 0-indexed.
88+
let quiet_bit_mask = 1 << (F::PRECISION - 2);
89+
// Unset the bit. Double-check that this wasn't the last bit set in the payload.
90+
// (which would turn the NaN into an infinity).
91+
let f = F::from_bits(f.to_bits() & !quiet_bit_mask);
92+
if f.is_nan() { Some(f) } else { None }
93+
}
94+
95+
let this = self.eval_context_ref();
96+
let mut rand = this.machine.rng.borrow_mut();
97+
// Assemble an iterator of possible NaNs: preferred, quieting propagation, unchanged propagation.
98+
// On some targets there are more possibilities; for now we just generate those options that
99+
// are possible everywhere.
100+
let preferred_nan = F2::qnan(Some(0));
101+
let nans = iter::once(preferred_nan)
102+
.chain(inputs.iter().filter(|f| f.is_nan()).map(|&f| {
103+
// Regular apfloat cast is quieting.
104+
f.convert(&mut false).value
105+
}))
106+
.chain(inputs.iter().filter(|f| f.is_signaling()).filter_map(|&f| {
107+
let f: F2 = f.convert(&mut false).value;
108+
// We have to de-quiet this again for unchanged propagation.
109+
make_signaling(f)
110+
}));
111+
// Pick one of the NaNs.
112+
let nan = nans.choose(&mut *rand).unwrap();
113+
// Non-deterministically flip the sign.
114+
if rand.gen() {
115+
// This will properly flip even for NaN.
116+
-nan
117+
} else {
118+
nan
119+
}
120+
}
83121
}

0 commit comments

Comments
 (0)