Skip to content

Commit 7839cb9

Browse files
committed
Change enum->int casts to not go through MIR casts.
Instead we generate a discriminant rvalue and cast the result of that.
1 parent 0e674b3 commit 7839cb9

File tree

15 files changed

+221
-132
lines changed

15 files changed

+221
-132
lines changed

Diff for: compiler/rustc_codegen_cranelift/src/base.rs

-23
Original file line numberDiff line numberDiff line change
@@ -635,29 +635,6 @@ fn codegen_stmt<'tcx>(
635635
let (ptr, _extra) = operand.load_scalar_pair(fx);
636636
lval.write_cvalue(fx, CValue::by_val(ptr, dest_layout))
637637
}
638-
} else if let ty::Adt(adt_def, _substs) = from_ty.kind() {
639-
// enum -> discriminant value
640-
assert!(adt_def.is_enum());
641-
match to_ty.kind() {
642-
ty::Uint(_) | ty::Int(_) => {}
643-
_ => unreachable!("cast adt {} -> {}", from_ty, to_ty),
644-
}
645-
let to_clif_ty = fx.clif_type(to_ty).unwrap();
646-
647-
let discriminant = crate::discriminant::codegen_get_discriminant(
648-
fx,
649-
operand,
650-
fx.layout_of(operand.layout().ty.discriminant_ty(fx.tcx)),
651-
)
652-
.load_scalar(fx);
653-
654-
let res = crate::cast::clif_intcast(
655-
fx,
656-
discriminant,
657-
to_clif_ty,
658-
to_ty.is_signed(),
659-
);
660-
lval.write_cvalue(fx, CValue::by_val(res, dest_layout));
661638
} else {
662639
let to_clif_ty = fx.clif_type(to_ty).unwrap();
663640
let from = operand.load_scalar(fx);

Diff for: compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+8-70
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use rustc_middle::ty::cast::{CastTy, IntTy};
1212
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf};
1313
use rustc_middle::ty::{self, adjustment::PointerCast, Instance, Ty, TyCtxt};
1414
use rustc_span::source_map::{Span, DUMMY_SP};
15-
use rustc_target::abi::{Abi, Int, Variants};
1615

1716
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
1817
#[instrument(level = "debug", skip(self, bx))]
@@ -283,74 +282,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
283282
CastTy::from_ty(operand.layout.ty).expect("bad input type for cast");
284283
let r_t_out = CastTy::from_ty(cast.ty).expect("bad output type for cast");
285284
let ll_t_in = bx.cx().immediate_backend_type(operand.layout);
286-
match operand.layout.variants {
287-
Variants::Single { index } => {
288-
if let Some(discr) =
289-
operand.layout.ty.discriminant_for_variant(bx.tcx(), index)
290-
{
291-
let discr_layout = bx.cx().layout_of(discr.ty);
292-
let discr_t = bx.cx().immediate_backend_type(discr_layout);
293-
let discr_val = bx.cx().const_uint_big(discr_t, discr.val);
294-
let discr_val =
295-
bx.intcast(discr_val, ll_t_out, discr.ty.is_signed());
296-
297-
return (
298-
bx,
299-
OperandRef {
300-
val: OperandValue::Immediate(discr_val),
301-
layout: cast,
302-
},
303-
);
304-
}
305-
}
306-
Variants::Multiple { .. } => {}
307-
}
308285
let llval = operand.immediate();
309286

310-
let mut signed = false;
311-
if let Abi::Scalar(scalar) = operand.layout.abi {
312-
if let Int(_, s) = scalar.primitive() {
313-
// We use `i1` for bytes that are always `0` or `1`,
314-
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
315-
// let LLVM interpret the `i1` as signed, because
316-
// then `i1 1` (i.e., E::B) is effectively `i8 -1`.
317-
signed = !scalar.is_bool() && s;
318-
319-
if !scalar.is_always_valid(bx.cx())
320-
&& scalar.valid_range(bx.cx()).end
321-
>= scalar.valid_range(bx.cx()).start
322-
{
323-
// We want `table[e as usize ± k]` to not
324-
// have bound checks, and this is the most
325-
// convenient place to put the `assume`s.
326-
if scalar.valid_range(bx.cx()).start > 0 {
327-
let enum_value_lower_bound = bx.cx().const_uint_big(
328-
ll_t_in,
329-
scalar.valid_range(bx.cx()).start,
330-
);
331-
let cmp_start = bx.icmp(
332-
IntPredicate::IntUGE,
333-
llval,
334-
enum_value_lower_bound,
335-
);
336-
bx.assume(cmp_start);
337-
}
338-
339-
let enum_value_upper_bound = bx
340-
.cx()
341-
.const_uint_big(ll_t_in, scalar.valid_range(bx.cx()).end);
342-
let cmp_end = bx.icmp(
343-
IntPredicate::IntULE,
344-
llval,
345-
enum_value_upper_bound,
346-
);
347-
bx.assume(cmp_end);
348-
}
349-
}
350-
}
351-
352287
let newval = match (r_t_in, r_t_out) {
353-
(CastTy::Int(_), CastTy::Int(_)) => bx.intcast(llval, ll_t_out, signed),
288+
(CastTy::Int(i), CastTy::Int(_)) => {
289+
bx.intcast(llval, ll_t_out, matches!(i, IntTy::I))
290+
}
354291
(CastTy::Float, CastTy::Float) => {
355292
let srcsz = bx.cx().float_width(ll_t_in);
356293
let dstsz = bx.cx().float_width(ll_t_out);
@@ -362,8 +299,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
362299
llval
363300
}
364301
}
365-
(CastTy::Int(_), CastTy::Float) => {
366-
if signed {
302+
(CastTy::Int(i), CastTy::Float) => {
303+
if matches!(i, IntTy::I) {
367304
bx.sitofp(llval, ll_t_out)
368305
} else {
369306
bx.uitofp(llval, ll_t_out)
@@ -372,8 +309,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
372309
(CastTy::Ptr(_) | CastTy::FnPtr, CastTy::Ptr(_)) => {
373310
bx.pointercast(llval, ll_t_out)
374311
}
375-
(CastTy::Int(_), CastTy::Ptr(_)) => {
376-
let usize_llval = bx.intcast(llval, bx.cx().type_isize(), signed);
312+
(CastTy::Int(i), CastTy::Ptr(_)) => {
313+
let usize_llval =
314+
bx.intcast(llval, bx.cx().type_isize(), matches!(i, IntTy::I));
377315
bx.inttoptr(usize_llval, ll_t_out)
378316
}
379317
(CastTy::Float, CastTy::Int(IntTy::I)) => {

Diff for: compiler/rustc_const_eval/src/interpret/cast.rs

+2-23
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_middle::mir::CastKind;
88
use rustc_middle::ty::adjustment::PointerCast;
99
use rustc_middle::ty::layout::{IntegerExt, LayoutOf, TyAndLayout};
1010
use rustc_middle::ty::{self, FloatTy, Ty, TypeAndMut};
11-
use rustc_target::abi::{Integer, Variants};
11+
use rustc_target::abi::Integer;
1212
use rustc_type_ir::sty::TyKind::*;
1313

1414
use super::{
@@ -127,12 +127,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
127127
Float(FloatTy::F64) => {
128128
return Ok(self.cast_from_float(src.to_scalar()?.to_f64()?, cast_ty).into());
129129
}
130-
// The rest is integer/pointer-"like", including fn ptr casts and casts from enums that
131-
// are represented as integers.
130+
// The rest is integer/pointer-"like", including fn ptr casts
132131
_ => assert!(
133132
src.layout.ty.is_bool()
134133
|| src.layout.ty.is_char()
135-
|| src.layout.ty.is_enum()
136134
|| src.layout.ty.is_integral()
137135
|| src.layout.ty.is_any_ptr(),
138136
"Unexpected cast from type {:?}",
@@ -142,25 +140,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
142140

143141
// # First handle non-scalar source values.
144142

145-
// Handle cast from a ZST enum (0 or 1 variants).
146-
match src.layout.variants {
147-
Variants::Single { index } => {
148-
if src.layout.abi.is_uninhabited() {
149-
// This is dead code, because an uninhabited enum is UB to
150-
// instantiate.
151-
throw_ub!(Unreachable);
152-
}
153-
if let Some(discr) = src.layout.ty.discriminant_for_variant(*self.tcx, index) {
154-
assert!(src.layout.is_zst());
155-
let discr_layout = self.layout_of(discr.ty)?;
156-
157-
let scalar = Scalar::from_uint(discr.val, discr_layout.layout.size());
158-
return Ok(self.cast_from_int_like(scalar, discr_layout, cast_ty)?.into());
159-
}
160-
}
161-
Variants::Multiple { .. } => {}
162-
}
163-
164143
// Handle casting any ptr to raw ptr (might be a fat ptr).
165144
if src.layout.ty.is_any_ptr() && cast_ty.is_unsafe_ptr() {
166145
let dest_layout = self.layout_of(cast_ty)?;

Diff for: compiler/rustc_const_eval/src/transform/validate.rs

+28-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use rustc_middle::mir::interpret::Scalar;
77
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
88
use rustc_middle::mir::visit::{PlaceContext, Visitor};
99
use rustc_middle::mir::{
10-
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, Local, Location, MirPass,
11-
MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope, Statement,
12-
StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
10+
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, Local, Location,
11+
MirPass, MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope,
12+
Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
1313
};
1414
use rustc_middle::ty::fold::BottomUpFolder;
1515
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable};
@@ -361,6 +361,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
361361
);
362362
}
363363
}
364+
Rvalue::Ref(..) => {}
364365
Rvalue::Len(p) => {
365366
let pty = p.ty(&self.body.local_decls, self.tcx).ty;
366367
check_kinds!(
@@ -503,7 +504,30 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
503504
let a = operand.ty(&self.body.local_decls, self.tcx);
504505
check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..));
505506
}
506-
_ => {}
507+
Rvalue::Cast(kind, operand, target_type) => {
508+
match kind {
509+
CastKind::Misc => {
510+
let op_ty = operand.ty(self.body, self.tcx);
511+
if op_ty.is_enum() {
512+
self.fail(
513+
location,
514+
format!(
515+
"enum -> int casts should go through `Rvalue::Discriminant`: {operand:?}:{op_ty} as {target_type}",
516+
),
517+
);
518+
}
519+
}
520+
// Nothing to check here
521+
CastKind::PointerFromExposedAddress
522+
| CastKind::PointerExposeAddress
523+
| CastKind::Pointer(_) => {}
524+
}
525+
}
526+
Rvalue::Repeat(_, _)
527+
| Rvalue::ThreadLocalRef(_)
528+
| Rvalue::AddressOf(_, _)
529+
| Rvalue::NullaryOp(_, _)
530+
| Rvalue::Discriminant(_) => {}
507531
}
508532
self.super_rvalue(rvalue, location);
509533
}

Diff for: compiler/rustc_mir_build/src/build/expr/as_rvalue.rs

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! See docs in `build/expr/mod.rs`.
22
33
use rustc_index::vec::Idx;
4+
use rustc_middle::ty::util::IntTypeExt;
45

56
use crate::build::expr::as_place::PlaceBase;
67
use crate::build::expr::category::{Category, RvalueFunc};
@@ -190,7 +191,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
190191
}
191192
ExprKind::Cast { source } => {
192193
let source = &this.thir[source];
193-
let from_ty = CastTy::from_ty(source.ty);
194+
195+
// Casting an enum to an integer is equivalent to computing the discriminant and casting the
196+
// discriminant. Previously every backend had to repeat the logic for this operation. Now we
197+
// create all the steps directly in MIR with operations all backends need to support anyway.
198+
let (source, ty) = if let ty::Adt(adt_def, ..) = source.ty.kind() && adt_def.is_enum() {
199+
let discr_ty = adt_def.repr().discr_type().to_ty(this.tcx);
200+
let place = unpack!(block = this.as_place(block, source));
201+
let discr = this.temp(discr_ty, source.span);
202+
this.cfg.push_assign(
203+
block,
204+
source_info,
205+
discr,
206+
Rvalue::Discriminant(place),
207+
);
208+
209+
(Operand::Move(discr), discr_ty)
210+
} else {
211+
let ty = source.ty;
212+
let source = unpack!(
213+
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
214+
);
215+
(source, ty)
216+
};
217+
let from_ty = CastTy::from_ty(ty);
194218
let cast_ty = CastTy::from_ty(expr.ty);
195219
let cast_kind = match (from_ty, cast_ty) {
196220
(Some(CastTy::Ptr(_) | CastTy::FnPtr), Some(CastTy::Int(_))) => {
@@ -201,9 +225,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
201225
}
202226
(_, _) => CastKind::Misc,
203227
};
204-
let source = unpack!(
205-
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
206-
);
207228
block.and(Rvalue::Cast(cast_kind, source, expr.ty))
208229
}
209230
ExprKind::Pointer { cast, source } => {

Diff for: src/test/codegen/enum-bounds-check-derived-idx.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ pub enum Bar {
1212
// CHECK-LABEL: @lookup_inc
1313
#[no_mangle]
1414
pub fn lookup_inc(buf: &[u8; 5], f: Bar) -> u8 {
15-
// CHECK-NOT: panic_bounds_check
15+
// FIXME: panic check can be removed by adding the assumes back after https://github.com/rust-lang/rust/pull/98332
16+
// CHECK: panic_bounds_check
1617
buf[f as usize + 1]
1718
}
1819

1920
// CHECK-LABEL: @lookup_dec
2021
#[no_mangle]
2122
pub fn lookup_dec(buf: &[u8; 5], f: Bar) -> u8 {
22-
// CHECK-NOT: panic_bounds_check
23+
// FIXME: panic check can be removed by adding the assumes back after https://github.com/rust-lang/rust/pull/98332
24+
// CHECK: panic_bounds_check
2325
buf[f as usize - 1]
2426
}

Diff for: src/test/codegen/enum-bounds-check-issue-13926.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum Exception {
1313
// CHECK-LABEL: @access
1414
#[no_mangle]
1515
pub fn access(array: &[usize; 12], exc: Exception) -> usize {
16-
// CHECK-NOT: panic_bounds_check
16+
// FIXME: panic check can be removed by adding the assumes back after https://github.com/rust-lang/rust/pull/98332
17+
// CHECK: panic_bounds_check
1718
array[(exc as u8 - 4) as usize]
1819
}

Diff for: src/test/codegen/enum-bounds-check-issue-82871.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// compile-flags: -O
1+
// compile-flags: -C opt-level=0
22

33
#![crate_type = "lib"]
44

@@ -9,7 +9,10 @@ pub enum E {
99

1010
// CHECK-LABEL: @index
1111
#[no_mangle]
12-
pub fn index(x: &[u32; 3], ind: E) -> u32{
13-
// CHECK-NOT: panic_bounds_check
12+
pub fn index(x: &[u32; 3], ind: E) -> u32 {
13+
// Canary: we should be able to optimize out the bounds check, but we need
14+
// to track the range of the discriminant result in order to be able to do that.
15+
// oli-obk tried to add that, but that caused miscompilations all over the place.
16+
// CHECK: panic_bounds_check
1417
x[ind as usize]
1518
}

Diff for: src/test/codegen/enum-bounds-check.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub enum Bar {
2121
// CHECK-LABEL: @lookup_unmodified
2222
#[no_mangle]
2323
pub fn lookup_unmodified(buf: &[u8; 5], f: Bar) -> u8 {
24-
// CHECK-NOT: panic_bounds_check
24+
// FIXME: panic check can be removed by adding the assumes back after https://github.com/rust-lang/rust/pull/98332
25+
// CHECK: panic_bounds_check
2526
buf[f as usize]
2627
}

Diff for: src/test/mir-opt/enum_cast.bar.mir_map.0.mir

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// MIR for `bar` 0 mir_map
2+
3+
fn bar(_1: Bar) -> usize {
4+
debug bar => _1; // in scope 0 at $DIR/enum_cast.rs:22:8: 22:11
5+
let mut _0: usize; // return place in scope 0 at $DIR/enum_cast.rs:22:21: 22:26
6+
let mut _2: isize; // in scope 0 at $DIR/enum_cast.rs:23:5: 23:8
7+
8+
bb0: {
9+
_2 = discriminant(_1); // scope 0 at $DIR/enum_cast.rs:23:5: 23:17
10+
_0 = move _2 as usize (Misc); // scope 0 at $DIR/enum_cast.rs:23:5: 23:17
11+
return; // scope 0 at $DIR/enum_cast.rs:24:2: 24:2
12+
}
13+
}

Diff for: src/test/mir-opt/enum_cast.boo.mir_map.0.mir

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// MIR for `boo` 0 mir_map
2+
3+
fn boo(_1: Boo) -> usize {
4+
debug boo => _1; // in scope 0 at $DIR/enum_cast.rs:26:8: 26:11
5+
let mut _0: usize; // return place in scope 0 at $DIR/enum_cast.rs:26:21: 26:26
6+
let mut _2: u8; // in scope 0 at $DIR/enum_cast.rs:27:5: 27:8
7+
8+
bb0: {
9+
_2 = discriminant(_1); // scope 0 at $DIR/enum_cast.rs:27:5: 27:17
10+
_0 = move _2 as usize (Misc); // scope 0 at $DIR/enum_cast.rs:27:5: 27:17
11+
return; // scope 0 at $DIR/enum_cast.rs:28:2: 28:2
12+
}
13+
}

0 commit comments

Comments
 (0)