Skip to content

Commit aa7e5ca

Browse files
committed
Refactor UninhabitedEnumBranching to mark targets unreachable.
1 parent 28f6ea4 commit aa7e5ca

6 files changed

+83
-67
lines changed

compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs

+46-61
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
use crate::MirPass;
44
use rustc_data_structures::fx::FxHashSet;
55
use rustc_middle::mir::{
6-
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
7-
TerminatorKind,
6+
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
87
};
98
use rustc_middle::ty::layout::TyAndLayout;
109
use rustc_middle::ty::{Ty, TyCtxt};
@@ -30,18 +29,16 @@ fn get_switched_on_type<'tcx>(
3029
let terminator = block_data.terminator();
3130

3231
// Only bother checking blocks which terminate by switching on a local.
33-
if let Some(local) = get_discriminant_local(&terminator.kind) {
34-
let stmt_before_term = (!block_data.statements.is_empty())
35-
.then(|| &block_data.statements[block_data.statements.len() - 1].kind);
36-
37-
if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term
38-
{
39-
if l.as_local() == Some(local) {
40-
let ty = place.ty(body, tcx).ty;
41-
if ty.is_enum() {
42-
return Some(ty);
43-
}
44-
}
32+
let local = get_discriminant_local(&terminator.kind)?;
33+
34+
let stmt_before_term = block_data.statements.last()?;
35+
36+
if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
37+
&& l.as_local() == Some(local)
38+
{
39+
let ty = place.ty(body, tcx).ty;
40+
if ty.is_enum() {
41+
return Some(ty);
4542
}
4643
}
4744

@@ -72,28 +69,6 @@ fn variant_discriminants<'tcx>(
7269
}
7370
}
7471

75-
/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
76-
/// bb to use as the new target if not.
77-
fn ensure_otherwise_unreachable<'tcx>(
78-
body: &Body<'tcx>,
79-
targets: &SwitchTargets,
80-
) -> Option<BasicBlockData<'tcx>> {
81-
let otherwise = targets.otherwise();
82-
let bb = &body.basic_blocks[otherwise];
83-
if bb.terminator().kind == TerminatorKind::Unreachable
84-
&& bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
85-
{
86-
return None;
87-
}
88-
89-
let mut new_block = BasicBlockData::new(Some(Terminator {
90-
source_info: bb.terminator().source_info,
91-
kind: TerminatorKind::Unreachable,
92-
}));
93-
new_block.is_cleanup = bb.is_cleanup;
94-
Some(new_block)
95-
}
96-
9772
impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
9873
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
9974
sess.mir_opt_level() > 0
@@ -102,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
10277
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
10378
trace!("UninhabitedEnumBranching starting for {:?}", body.source);
10479

105-
for bb in body.basic_blocks.indices() {
80+
let mut removable_switchs = Vec::new();
81+
82+
for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
10683
trace!("processing block {:?}", bb);
10784

108-
let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
109-
else {
85+
if bb_data.is_cleanup {
11086
continue;
111-
};
87+
}
88+
89+
let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };
11290

11391
let layout = tcx.layout_of(
11492
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
@@ -122,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
122100

123101
trace!("allowed_variants = {:?}", allowed_variants);
124102

125-
if let TerminatorKind::SwitchInt { targets, .. } =
126-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
127-
{
128-
let mut new_targets = SwitchTargets::new(
129-
targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
130-
targets.otherwise(),
131-
);
132-
133-
if new_targets.iter().count() == allowed_variants.len() {
134-
if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
135-
let new_otherwise = body.basic_blocks_mut().push(updated);
136-
*new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
137-
}
138-
}
103+
let terminator = bb_data.terminator();
104+
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
139105

140-
if let TerminatorKind::SwitchInt { targets, .. } =
141-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
142-
{
143-
*targets = new_targets;
106+
let mut reachable_count = 0;
107+
for (index, (val, _)) in targets.iter().enumerate() {
108+
if allowed_variants.contains(&val) {
109+
reachable_count += 1;
144110
} else {
145-
unreachable!()
111+
removable_switchs.push((bb, index));
146112
}
147-
} else {
148-
unreachable!()
149113
}
114+
115+
if reachable_count == allowed_variants.len() {
116+
removable_switchs.push((bb, targets.iter().count()));
117+
}
118+
}
119+
120+
if removable_switchs.is_empty() {
121+
return;
122+
}
123+
124+
let new_block = BasicBlockData::new(Some(Terminator {
125+
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
126+
kind: TerminatorKind::Unreachable,
127+
}));
128+
let unreachable_block = body.basic_blocks.as_mut().push(new_block);
129+
130+
for (bb, index) in removable_switchs {
131+
let bb = &mut body.basic_blocks.as_mut()[bb];
132+
let terminator = bb.terminator_mut();
133+
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
134+
targets.all_targets_mut()[index] = unreachable_block;
150135
}
151136
}
152137
}

tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@ fn main() -> () {
1212
let mut _8: isize;
1313
let _9: &str;
1414
let mut _10: bool;
15+
let mut _11: bool;
16+
let mut _12: bool;
1517

1618
bb0: {
1719
StorageLive(_1);
1820
StorageLive(_2);
1921
_2 = Test1::C;
2022
_3 = discriminant(_2);
21-
_10 = Eq(_3, const 2_isize);
23+
_10 = Ne(_3, const 0_isize);
2224
assume(move _10);
25+
_11 = Ne(_3, const 1_isize);
26+
assume(move _11);
27+
_12 = Eq(_3, const 2_isize);
28+
assume(move _12);
2329
StorageLive(_5);
2430
_5 = const "C";
2531
_1 = &(*_5);

tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff

+7-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_2 = Test1::C;
2020
_3 = discriminant(_2);
2121
- switchInt(move _3) -> [0: bb3, 1: bb4, 2: bb1, otherwise: bb2];
22-
+ switchInt(move _3) -> [2: bb1, otherwise: bb2];
22+
+ switchInt(move _3) -> [0: bb9, 1: bb9, 2: bb1, otherwise: bb9];
2323
}
2424

2525
bb1: {
@@ -54,7 +54,8 @@
5454
StorageLive(_7);
5555
_7 = Test2::D;
5656
_8 = discriminant(_7);
57-
switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
57+
- switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
58+
+ switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb9];
5859
}
5960

6061
bb6: {
@@ -75,6 +76,10 @@
7576
StorageDead(_6);
7677
_0 = const ();
7778
return;
79+
+ }
80+
+
81+
+ bb9: {
82+
+ unreachable;
7883
}
7984
}
8085

tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ fn main() -> () {
1515
let _11: &str;
1616
let _12: &str;
1717
let _13: &str;
18+
let mut _14: bool;
19+
let mut _15: bool;
20+
let mut _16: bool;
21+
let mut _17: bool;
1822
scope 1 {
1923
debug plop => _1;
2024
}
@@ -29,6 +33,10 @@ fn main() -> () {
2933
StorageLive(_4);
3034
_4 = &(_1.1: Test1);
3135
_5 = discriminant((*_4));
36+
_16 = Ne(_5, const 0_isize);
37+
assume(move _16);
38+
_17 = Ne(_5, const 1_isize);
39+
assume(move _17);
3240
switchInt(move _5) -> [2: bb3, 3: bb1, otherwise: bb2];
3341
}
3442

@@ -57,6 +65,10 @@ fn main() -> () {
5765
StorageDead(_3);
5866
StorageLive(_9);
5967
_10 = discriminant((_1.1: Test1));
68+
_14 = Ne(_10, const 0_isize);
69+
assume(move _14);
70+
_15 = Ne(_10, const 1_isize);
71+
assume(move _15);
6072
switchInt(move _10) -> [2: bb6, 3: bb5, otherwise: bb2];
6173
}
6274

tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff

+6-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
_4 = &(_1.1: Test1);
3232
_5 = discriminant((*_4));
3333
- switchInt(move _5) -> [0: bb3, 1: bb4, 2: bb5, 3: bb1, otherwise: bb2];
34-
+ switchInt(move _5) -> [2: bb5, 3: bb1, otherwise: bb2];
34+
+ switchInt(move _5) -> [0: bb12, 1: bb12, 2: bb5, 3: bb1, otherwise: bb12];
3535
}
3636

3737
bb1: {
@@ -73,7 +73,7 @@
7373
StorageLive(_9);
7474
_10 = discriminant((_1.1: Test1));
7575
- switchInt(move _10) -> [0: bb8, 1: bb9, 2: bb10, 3: bb7, otherwise: bb2];
76-
+ switchInt(move _10) -> [2: bb10, 3: bb7, otherwise: bb2];
76+
+ switchInt(move _10) -> [0: bb12, 1: bb12, 2: bb10, 3: bb7, otherwise: bb12];
7777
}
7878

7979
bb7: {
@@ -110,6 +110,10 @@
110110
_0 = const ();
111111
StorageDead(_1);
112112
return;
113+
+ }
114+
+
115+
+ bb12: {
116+
+ unreachable;
113117
}
114118
}
115119

tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
bb0: {
1010
_2 = discriminant(_1);
1111
- switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1];
12-
+ switchInt(move _2) -> [1: bb3, otherwise: bb1];
12+
+ switchInt(move _2) -> [0: bb5, 1: bb3, otherwise: bb1];
1313
}
1414

1515
bb1: {
@@ -29,6 +29,10 @@
2929

3030
bb4: {
3131
return;
32+
+ }
33+
+
34+
+ bb5: {
35+
+ unreachable;
3236
}
3337
}
3438

0 commit comments

Comments
 (0)