Skip to content

Commit dddf046

Browse files
committed
mir-opt: Merge all branch BBs into a single copy statement
1 parent b2cedc4 commit dddf046

13 files changed

+602
-61
lines changed

Diff for: compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ mod lower_intrinsics;
8585
mod lower_slice_len;
8686
mod match_branches;
8787
mod mentioned_items;
88+
mod merge_branches;
8889
mod multiple_return_terminators;
8990
mod nrvo;
9091
mod post_drop_elaboration;
@@ -609,6 +610,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
609610
&dead_store_elimination::DeadStoreElimination::Initial,
610611
&gvn::GVN,
611612
&simplify::SimplifyLocals::AfterGVN,
613+
&merge_branches::MergeBranchSimplification,
612614
&dataflow_const_prop::DataflowConstProp,
613615
&single_use_consts::SingleUseConsts,
614616
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),

Diff for: compiler/rustc_mir_transform/src/merge_branches.rs

+279
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
//! This pass attempts to merge all branches to eliminate switch terminator.
2+
//! Ideally, we could combine it with `MatchBranchSimplification`, as these two passes
3+
//! match and merge statements with different patterns. Given the compile time and
4+
//! code complexity, we have not merged them into a more general pass for now.
5+
use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
6+
use rustc_index::bit_set::BitSet;
7+
use rustc_middle::mir::patch::MirPatch;
8+
use rustc_middle::mir::*;
9+
use rustc_middle::ty;
10+
use rustc_middle::ty::util::Discr;
11+
use rustc_middle::ty::{ParamEnv, TyCtxt};
12+
use rustc_mir_dataflow::impls::{MaybeTransitiveLiveLocals, borrowed_locals};
13+
use rustc_mir_dataflow::{Analysis, ResultsCursor};
14+
15+
pub(super) struct MergeBranchSimplification;
16+
17+
impl<'tcx> crate::MirPass<'tcx> for MergeBranchSimplification {
18+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19+
sess.mir_opt_level() >= 2
20+
}
21+
22+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23+
let def_id = body.source.def_id();
24+
let param_env = tcx.param_env_reveal_all_normalized(def_id);
25+
26+
let borrowed_locals = borrowed_locals(body);
27+
let mut maybe_live: ResultsCursor<'_, '_, MaybeTransitiveLiveLocals<'_>> =
28+
MaybeTransitiveLiveLocals::new(&borrowed_locals)
29+
.into_engine(tcx, body)
30+
.iterate_to_fixpoint()
31+
.into_results_cursor(body);
32+
for i in 0..body.basic_blocks.len() {
33+
let bbs = &*body.basic_blocks;
34+
let switch_bb_idx = BasicBlock::from_usize(i);
35+
let Some((switch_discr, targets)) = bbs[switch_bb_idx].terminator().kind.as_switch()
36+
else {
37+
continue;
38+
};
39+
// Check if the copy source matches the following pattern.
40+
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
41+
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
42+
let Some(&Statement {
43+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(src_place))),
44+
..
45+
}) = bbs[switch_bb_idx].statements.last()
46+
else {
47+
continue;
48+
};
49+
if switch_discr.place() != Some(discr_place) {
50+
continue;
51+
}
52+
let src_ty = src_place.ty(body.local_decls(), tcx);
53+
if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
54+
continue;
55+
}
56+
// We require that the possible target blocks all be distinct.
57+
if !targets.is_distinct() {
58+
continue;
59+
}
60+
if !bbs[targets.otherwise()].is_empty_unreachable() {
61+
continue;
62+
}
63+
// Check that destinations are identical, and if not, then don't optimize this block.
64+
let mut targets_iter = targets.iter();
65+
let first_terminator_kind = &bbs[targets_iter.next().unwrap().1].terminator().kind;
66+
if !targets_iter.all(|(_, other_target)| {
67+
first_terminator_kind == &bbs[other_target].terminator().kind
68+
}) {
69+
continue;
70+
}
71+
if let Some(dest_place) = can_simplify_to_copy(
72+
tcx,
73+
param_env,
74+
body,
75+
targets,
76+
src_place,
77+
src_ty,
78+
&borrowed_locals,
79+
&mut maybe_live,
80+
) {
81+
let statement_index = bbs[switch_bb_idx].statements.len();
82+
let parent_end = Location { block: switch_bb_idx, statement_index };
83+
let mut patch = MirPatch::new(body);
84+
patch.add_assign(parent_end, dest_place, Rvalue::Use(Operand::Copy(src_place)));
85+
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
86+
patch.apply(body);
87+
super::simplify::remove_dead_blocks(body);
88+
// After modifying the MIR, the result of `MaybeTransitiveLiveLocals` may become invalid,
89+
// keeping it simple to process only once.
90+
break;
91+
}
92+
}
93+
}
94+
}
95+
96+
/// The GVN simplified
97+
/// ```ignore (syntax-highlighting-only)
98+
/// match a {
99+
/// Foo::A(x) => Foo::A(*x),
100+
/// Foo::B => Foo::B
101+
/// }
102+
/// ```
103+
/// to
104+
/// ```ignore (syntax-highlighting-only)
105+
/// match a {
106+
/// Foo::A(_x) => a, // copy a
107+
/// Foo::B => Foo::B
108+
/// }
109+
/// ```
110+
/// This function answers whether it can be simplified to a copy statement
111+
/// by returning the copy destination.
112+
fn can_simplify_to_copy<'tcx>(
113+
tcx: TyCtxt<'tcx>,
114+
param_env: ParamEnv<'tcx>,
115+
body: &Body<'tcx>,
116+
targets: &SwitchTargets,
117+
src_place: Place<'tcx>,
118+
src_ty: tcx::PlaceTy<'tcx>,
119+
borrowed_locals: &BitSet<Local>,
120+
maybe_live: &mut ResultsCursor<'_, 'tcx, MaybeTransitiveLiveLocals<'_>>,
121+
) -> Option<Place<'tcx>> {
122+
let mut targets_iter = targets.iter();
123+
let dest_place = targets_iter.next().and_then(|(index, target)| {
124+
find_copy_assign(
125+
tcx,
126+
param_env,
127+
body,
128+
index,
129+
target,
130+
src_place,
131+
src_ty,
132+
borrowed_locals,
133+
maybe_live,
134+
)
135+
})?;
136+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
137+
if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
138+
return None;
139+
}
140+
if targets_iter.any(|(other_index, other_target)| {
141+
Some(dest_place)
142+
!= find_copy_assign(
143+
tcx,
144+
param_env,
145+
body,
146+
other_index,
147+
other_target,
148+
src_place,
149+
src_ty,
150+
borrowed_locals,
151+
maybe_live,
152+
)
153+
}) {
154+
return None;
155+
}
156+
Some(dest_place)
157+
}
158+
159+
fn find_copy_assign<'tcx>(
160+
tcx: TyCtxt<'tcx>,
161+
param_env: ParamEnv<'tcx>,
162+
body: &Body<'tcx>,
163+
index: u128,
164+
target_block: BasicBlock,
165+
src_place: Place<'tcx>,
166+
src_ty: tcx::PlaceTy<'tcx>,
167+
borrowed_locals: &BitSet<Local>,
168+
maybe_live: &mut ResultsCursor<'_, 'tcx, MaybeTransitiveLiveLocals<'_>>,
169+
) -> Option<Place<'tcx>> {
170+
let statements = &body.basic_blocks[target_block].statements;
171+
if statements.is_empty() {
172+
return None;
173+
}
174+
let assign_stmt = if statements.len() == 1 {
175+
0
176+
} else {
177+
// We are matching a statement copied from the source to the same destination from the BB,
178+
// and dead statements can be ignored.
179+
// We can treat the rvalue is the source if it's equal to the source.
180+
let mut lived_stmts: BitSet<usize> = BitSet::new_filled(statements.len());
181+
let mut expected_assign_stmt = None;
182+
for (statement_index, statement) in statements.iter().enumerate().rev() {
183+
let loc = Location { block: target_block, statement_index };
184+
if let StatementKind::Assign(assign) = &statement.kind {
185+
if !assign.1.is_safe_to_remove() {
186+
return None;
187+
}
188+
}
189+
match &statement.kind {
190+
StatementKind::Assign(box (dest_place, _))
191+
| StatementKind::SetDiscriminant { place: box dest_place, .. }
192+
| StatementKind::Deinit(box dest_place) => {
193+
if dest_place.is_indirect() || borrowed_locals.contains(dest_place.local) {
194+
return None;
195+
}
196+
maybe_live.seek_before_primary_effect(loc);
197+
if !maybe_live.get().contains(dest_place.local) {
198+
lived_stmts.remove(statement_index);
199+
} else if matches!(statement.kind, StatementKind::Assign(_))
200+
&& expected_assign_stmt.is_none()
201+
{
202+
// There is only one statement that cannot be ignored
203+
// that can be used as an expected copy statement.
204+
expected_assign_stmt = Some(statement_index);
205+
lived_stmts.remove(statement_index);
206+
} else {
207+
return None;
208+
}
209+
}
210+
StatementKind::StorageLive(_)
211+
| StatementKind::StorageDead(_)
212+
| StatementKind::Nop => (),
213+
214+
StatementKind::Retag(_, _)
215+
| StatementKind::Coverage(_)
216+
| StatementKind::Intrinsic(_)
217+
| StatementKind::ConstEvalCounter
218+
| StatementKind::PlaceMention(_)
219+
| StatementKind::FakeRead(_)
220+
| StatementKind::AscribeUserType(_, _) => {
221+
return None;
222+
}
223+
}
224+
}
225+
let expected_assign = expected_assign_stmt?;
226+
// We can ignore the paired StorageLive and StorageDead.
227+
let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len());
228+
for stmt_index in lived_stmts.iter() {
229+
let statement = &statements[stmt_index];
230+
match &statement.kind {
231+
StatementKind::StorageLive(local) if storage_live_locals.insert(*local) => {}
232+
StatementKind::StorageDead(local) if storage_live_locals.remove(*local) => {}
233+
StatementKind::Nop => {}
234+
_ => return None,
235+
}
236+
}
237+
if !storage_live_locals.is_empty() {
238+
return None;
239+
}
240+
expected_assign
241+
};
242+
let Statement { kind: StatementKind::Assign(box (dest_place, ref rvalue)), .. } =
243+
statements[assign_stmt]
244+
else {
245+
return None;
246+
};
247+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
248+
if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
249+
return None;
250+
}
251+
let ty::Adt(def, _) = dest_ty.ty.kind() else {
252+
return None;
253+
};
254+
match rvalue {
255+
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
256+
Rvalue::Use(Operand::Constant(box constant))
257+
if let Const::Val(const_, ty) = constant.const_ =>
258+
{
259+
let (ecx, op) = mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?;
260+
let variant = ecx.read_discriminant(&op).discard_err()?;
261+
if !def.variants()[variant].fields.is_empty() {
262+
return None;
263+
}
264+
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
265+
if val != index {
266+
return None;
267+
}
268+
}
269+
Rvalue::Use(Operand::Copy(place)) if *place == src_place => {}
270+
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
271+
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
272+
if fields.is_empty()
273+
&& let Some(Discr { val, .. }) =
274+
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
275+
&& val == index => {}
276+
_ => return None,
277+
}
278+
Some(dest_place)
279+
}

Diff for: tests/codegen/match-optimizes-away.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//
2-
//@ compile-flags: -O
1+
//@ compile-flags: -O -Cno-prepopulate-passes
2+
33
#![crate_type = "lib"]
44

55
pub enum Three {
@@ -19,8 +19,9 @@ pub enum Four {
1919
#[no_mangle]
2020
pub fn three_valued(x: Three) -> Three {
2121
// CHECK-LABEL: @three_valued
22-
// CHECK-NEXT: {{^.*:$}}
23-
// CHECK-NEXT: ret i8 %0
22+
// CHECK-SAME: (i8{{.*}} [[X:%x]])
23+
// CHECK-NEXT: start:
24+
// CHECK-NEXT: ret i8 [[X]]
2425
match x {
2526
Three::A => Three::A,
2627
Three::B => Three::B,
@@ -31,8 +32,9 @@ pub fn three_valued(x: Three) -> Three {
3132
#[no_mangle]
3233
pub fn four_valued(x: Four) -> Four {
3334
// CHECK-LABEL: @four_valued
34-
// CHECK-NEXT: {{^.*:$}}
35-
// CHECK-NEXT: ret i16 %0
35+
// CHECK-SAME: (i16{{.*}} [[X:%x]])
36+
// CHECK-NEXT: start:
37+
// CHECK-NEXT: ret i16 [[X]]
3638
match x {
3739
Four::A => Four::A,
3840
Four::B => Four::B,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
- // MIR for `no_fields` before MergeBranchSimplification
2+
+ // MIR for `no_fields` after MergeBranchSimplification
3+
4+
fn no_fields(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
+ _0 = copy _1;
13+
+ goto -> bb1;
14+
}
15+
16+
bb1: {
17+
- unreachable;
18+
- }
19+
-
20+
- bb2: {
21+
- _0 = NoFields::B;
22+
- goto -> bb4;
23+
- }
24+
-
25+
- bb3: {
26+
- _0 = NoFields::A;
27+
- goto -> bb4;
28+
- }
29+
-
30+
- bb4: {
31+
return;
32+
}
33+
}
34+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
- // MIR for `no_fields_failed` before MergeBranchSimplification
2+
+ // MIR for `no_fields_failed` after MergeBranchSimplification
3+
4+
fn no_fields_failed(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
}
13+
14+
bb1: {
15+
unreachable;
16+
}
17+
18+
bb2: {
19+
_0 = NoFields::A;
20+
goto -> bb4;
21+
}
22+
23+
bb3: {
24+
_0 = NoFields::B;
25+
goto -> bb4;
26+
}
27+
28+
bb4: {
29+
return;
30+
}
31+
}
32+

0 commit comments

Comments
 (0)