Skip to content

Commit 27bc686

Browse files
committed
Auto merge of #129931 - DianQK:match-br-copy, r=<try>
mir-opt: Merge all branch BBs into a single copy statement #128299 simplified ```rust match a { Foo::A(x) => Foo::A(*x), Foo::B => Foo::B } ``` to ```rust match a { Foo::A(x) => a, // copy a Foo::B => Foo::B } ``` The switch branch can be simplified into a single copy statement. This PR implements a relatively general simplification.
2 parents f174fd7 + f05d4af commit 27bc686

15 files changed

+622
-114
lines changed

compiler/rustc_mir_transform/src/dead_store_elimination.rs

+59-36
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,71 @@
1212
//! will still not cause any further changes.
1313
//!
1414
15+
use rustc_index::bit_set::DenseBitSet;
1516
use rustc_middle::bug;
1617
use rustc_middle::mir::visit::Visitor;
1718
use rustc_middle::mir::*;
1819
use rustc_middle::ty::TyCtxt;
19-
use rustc_mir_dataflow::Analysis;
2020
use rustc_mir_dataflow::debuginfo::debuginfo_locals;
2121
use rustc_mir_dataflow::impls::{
2222
LivenessTransferFunction, MaybeTransitiveLiveLocals, borrowed_locals,
2323
};
24+
use rustc_mir_dataflow::{Analysis, ResultsCursor};
2425

2526
use crate::util::is_within_packed;
2627

28+
pub(crate) struct DeadStoreAnalysis<'tcx, 'mir, 'a> {
29+
live: ResultsCursor<'mir, 'tcx, MaybeTransitiveLiveLocals<'a>>,
30+
always_live: &'a DenseBitSet<Local>,
31+
}
32+
33+
impl<'tcx, 'mir, 'a> DeadStoreAnalysis<'tcx, 'mir, 'a> {
34+
pub(crate) fn new(
35+
tcx: TyCtxt<'tcx>,
36+
body: &'mir Body<'tcx>,
37+
always_live: &'a DenseBitSet<Local>,
38+
) -> Self {
39+
let live = MaybeTransitiveLiveLocals::new(&always_live)
40+
.iterate_to_fixpoint(tcx, body, None)
41+
.into_results_cursor(body);
42+
Self { live, always_live }
43+
}
44+
45+
pub(crate) fn is_dead_store(&mut self, loc: Location, stmt_kind: &StatementKind<'tcx>) -> bool {
46+
if let StatementKind::Assign(assign) = stmt_kind {
47+
if !assign.1.is_safe_to_remove() {
48+
return false;
49+
}
50+
}
51+
match stmt_kind {
52+
StatementKind::Assign(box (place, _))
53+
| StatementKind::SetDiscriminant { place: box place, .. }
54+
| StatementKind::Deinit(box place) => {
55+
if !place.is_indirect() && !self.always_live.contains(place.local) {
56+
self.live.seek_before_primary_effect(loc);
57+
!self.live.get().contains(place.local)
58+
} else {
59+
false
60+
}
61+
}
62+
63+
StatementKind::Retag(_, _)
64+
| StatementKind::StorageLive(_)
65+
| StatementKind::StorageDead(_)
66+
| StatementKind::Coverage(_)
67+
| StatementKind::Intrinsic(_)
68+
| StatementKind::ConstEvalCounter
69+
| StatementKind::PlaceMention(_)
70+
| StatementKind::BackwardIncompatibleDropHint { .. }
71+
| StatementKind::Nop => false,
72+
73+
StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => {
74+
bug!("{:?} not found in this MIR phase!", stmt_kind)
75+
}
76+
}
77+
}
78+
}
79+
2780
/// Performs the optimization on the body
2881
///
2982
/// The `borrowed` set must be a `DenseBitSet` of all the locals that are ever borrowed in this
@@ -36,9 +89,7 @@ fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
3689
let mut always_live = debuginfo_locals(body);
3790
always_live.union(&borrowed_locals);
3891

39-
let mut live = MaybeTransitiveLiveLocals::new(&always_live)
40-
.iterate_to_fixpoint(tcx, body, None)
41-
.into_results_cursor(body);
92+
let mut analysis = DeadStoreAnalysis::new(tcx, body, &always_live);
4293

4394
// For blocks with a call terminator, if an argument copy can be turned into a move,
4495
// record it as (block, argument index).
@@ -50,8 +101,8 @@ fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
50101
let loc = Location { block: bb, statement_index: bb_data.statements.len() };
51102

52103
// Position ourselves between the evaluation of `args` and the write to `destination`.
53-
live.seek_to_block_end(bb);
54-
let mut state = live.get().clone();
104+
analysis.live.seek_to_block_end(bb);
105+
let mut state = analysis.live.get().clone();
55106

56107
for (index, arg) in args.iter().map(|a| &a.node).enumerate().rev() {
57108
if let Operand::Copy(place) = *arg
@@ -73,38 +124,10 @@ fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
73124
LivenessTransferFunction(&mut state).visit_operand(arg, loc);
74125
}
75126
}
76-
77127
for (statement_index, statement) in bb_data.statements.iter().enumerate().rev() {
78128
let loc = Location { block: bb, statement_index };
79-
if let StatementKind::Assign(assign) = &statement.kind {
80-
if !assign.1.is_safe_to_remove() {
81-
continue;
82-
}
83-
}
84-
match &statement.kind {
85-
StatementKind::Assign(box (place, _))
86-
| StatementKind::SetDiscriminant { place: box place, .. }
87-
| StatementKind::Deinit(box place) => {
88-
if !place.is_indirect() && !always_live.contains(place.local) {
89-
live.seek_before_primary_effect(loc);
90-
if !live.get().contains(place.local) {
91-
patch.push(loc);
92-
}
93-
}
94-
}
95-
StatementKind::Retag(_, _)
96-
| StatementKind::StorageLive(_)
97-
| StatementKind::StorageDead(_)
98-
| StatementKind::Coverage(_)
99-
| StatementKind::Intrinsic(_)
100-
| StatementKind::ConstEvalCounter
101-
| StatementKind::PlaceMention(_)
102-
| StatementKind::BackwardIncompatibleDropHint { .. }
103-
| StatementKind::Nop => {}
104-
105-
StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => {
106-
bug!("{:?} not found in this MIR phase!", statement.kind)
107-
}
129+
if analysis.is_dead_store(loc, &statement.kind) {
130+
patch.push(loc);
108131
}
109132
}
110133
}

compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ declare_passes! {
157157
mod lower_intrinsics : LowerIntrinsics;
158158
mod lower_slice_len : LowerSliceLenCalls;
159159
mod match_branches : MatchBranchSimplification;
160+
mod merge_branches: MergeBranchSimplification;
160161
mod mentioned_items : MentionedItems;
161162
mod multiple_return_terminators : MultipleReturnTerminators;
162163
mod nrvo : RenameReturnPlace;
@@ -707,6 +708,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
707708
&dead_store_elimination::DeadStoreElimination::Initial,
708709
&gvn::GVN,
709710
&simplify::SimplifyLocals::AfterGVN,
711+
&merge_branches::MergeBranchSimplification,
710712
&dataflow_const_prop::DataflowConstProp,
711713
&single_use_consts::SingleUseConsts,
712714
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
//! This pass attempts to merge all branches to eliminate switch terminator.
2+
//! Ideally, we could combine it with `MatchBranchSimplification`, as these two passes
3+
//! match and merge statements with different patterns. Given the compile time and
4+
//! code complexity, we have not merged them into a more general pass for now.
5+
use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
6+
use rustc_index::bit_set::DenseBitSet;
7+
use rustc_middle::mir::*;
8+
use rustc_middle::ty::util::Discr;
9+
use rustc_middle::ty::{self, TyCtxt};
10+
use rustc_mir_dataflow::impls::borrowed_locals;
11+
12+
use crate::dead_store_elimination::DeadStoreAnalysis;
13+
use crate::patch::MirPatch;
14+
15+
pub(super) struct MergeBranchSimplification;
16+
17+
impl<'tcx> crate::MirPass<'tcx> for MergeBranchSimplification {
18+
fn is_required(&self) -> bool {
19+
false
20+
}
21+
22+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
23+
sess.mir_opt_level() >= 2
24+
}
25+
26+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
27+
let typing_env = body.typing_env(tcx);
28+
let borrowed_locals = borrowed_locals(body);
29+
let mut dead_store_analysis = DeadStoreAnalysis::new(tcx, body, &borrowed_locals);
30+
31+
for switch_bb_idx in body.basic_blocks.indices() {
32+
let bbs = &*body.basic_blocks;
33+
let Some((switch_discr, targets)) = bbs[switch_bb_idx].terminator().kind.as_switch()
34+
else {
35+
continue;
36+
};
37+
// Check that destinations are identical, and if not, then don't optimize this block.
38+
let mut targets_iter = targets.iter();
39+
let first_terminator_kind = &bbs[targets_iter.next().unwrap().1].terminator().kind;
40+
if targets_iter.any(|(_, other_target)| {
41+
first_terminator_kind != &bbs[other_target].terminator().kind
42+
}) {
43+
continue;
44+
}
45+
// We require that the possible target blocks all be distinct.
46+
if !targets.is_distinct() {
47+
continue;
48+
}
49+
if !bbs[targets.otherwise()].is_empty_unreachable() {
50+
continue;
51+
}
52+
// Check if the copy source matches the following pattern.
53+
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
54+
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
55+
let Some(&Statement {
56+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(src_place))),
57+
..
58+
}) = bbs[switch_bb_idx].statements.last()
59+
else {
60+
continue;
61+
};
62+
if switch_discr.place() != Some(discr_place) {
63+
continue;
64+
}
65+
let src_ty = src_place.ty(body.local_decls(), tcx);
66+
if let Some(dest_place) = can_simplify_to_copy(
67+
tcx,
68+
typing_env,
69+
body,
70+
targets,
71+
src_place,
72+
src_ty,
73+
&mut dead_store_analysis,
74+
) {
75+
let statement_index = bbs[switch_bb_idx].statements.len();
76+
let parent_end = Location { block: switch_bb_idx, statement_index };
77+
let mut patch = MirPatch::new(body);
78+
patch.add_assign(parent_end, dest_place, Rvalue::Use(Operand::Copy(src_place)));
79+
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
80+
patch.apply(body);
81+
super::simplify::remove_dead_blocks(body);
82+
// After modifying the MIR, the result of `MaybeTransitiveLiveLocals` may become invalid,
83+
// keeping it simple to process only once.
84+
break;
85+
}
86+
}
87+
}
88+
}
89+
90+
/// The GVN simplified
91+
/// ```ignore (syntax-highlighting-only)
92+
/// match a {
93+
/// Foo::A(x) => Foo::A(*x),
94+
/// Foo::B => Foo::B
95+
/// }
96+
/// ```
97+
/// to
98+
/// ```ignore (syntax-highlighting-only)
99+
/// match a {
100+
/// Foo::A(_x) => a, // copy a
101+
/// Foo::B => Foo::B
102+
/// }
103+
/// ```
104+
/// This function answers whether it can be simplified to a copy statement
105+
/// by returning the copy destination.
106+
fn can_simplify_to_copy<'tcx>(
107+
tcx: TyCtxt<'tcx>,
108+
typing_env: ty::TypingEnv<'tcx>,
109+
body: &Body<'tcx>,
110+
targets: &SwitchTargets,
111+
src_place: Place<'tcx>,
112+
src_ty: PlaceTy<'tcx>,
113+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
114+
) -> Option<Place<'tcx>> {
115+
let mut targets_iter = targets.iter();
116+
let (first_index, first_target) = targets_iter.next()?;
117+
let dest_place = find_copy_assign(
118+
tcx,
119+
typing_env,
120+
body,
121+
first_index,
122+
first_target,
123+
src_place,
124+
src_ty,
125+
dead_store_analysis,
126+
)?;
127+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
128+
if dest_ty.ty != src_ty.ty {
129+
return None;
130+
}
131+
for (other_index, other_target) in targets_iter {
132+
if dest_place
133+
!= find_copy_assign(
134+
tcx,
135+
typing_env,
136+
body,
137+
other_index,
138+
other_target,
139+
src_place,
140+
src_ty,
141+
dead_store_analysis,
142+
)?
143+
{
144+
return None;
145+
}
146+
}
147+
Some(dest_place)
148+
}
149+
150+
// Find the single assignment statement where the source of the copy is from the source.
151+
// All other statements are dead statements or have no effect that can be eliminated.
152+
fn find_copy_assign<'tcx>(
153+
tcx: TyCtxt<'tcx>,
154+
typing_env: ty::TypingEnv<'tcx>,
155+
body: &Body<'tcx>,
156+
index: u128,
157+
target_block: BasicBlock,
158+
src_place: Place<'tcx>,
159+
src_ty: PlaceTy<'tcx>,
160+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
161+
) -> Option<Place<'tcx>> {
162+
let statements = &body.basic_blocks[target_block].statements;
163+
if statements.is_empty() {
164+
return None;
165+
}
166+
let assign_stmt = if statements.len() == 1 {
167+
0
168+
} else {
169+
let mut lived_stmts: DenseBitSet<usize> = DenseBitSet::new_filled(statements.len());
170+
let mut expected_assign_stmt = None;
171+
for (statement_index, statement) in statements.iter().enumerate().rev() {
172+
let loc = Location { block: target_block, statement_index };
173+
if dead_store_analysis.is_dead_store(loc, &statement.kind) {
174+
lived_stmts.remove(statement_index);
175+
} else if matches!(
176+
statement.kind,
177+
StatementKind::StorageLive(_) | StatementKind::StorageDead(_)
178+
) {
179+
} else if matches!(statement.kind, StatementKind::Assign(_))
180+
&& expected_assign_stmt.is_none()
181+
{
182+
// There is only one assign statement that cannot be ignored
183+
// that can be used as an expected copy statement.
184+
expected_assign_stmt = Some(statement_index);
185+
lived_stmts.remove(statement_index);
186+
} else {
187+
return None;
188+
}
189+
}
190+
let expected_assign = expected_assign_stmt?;
191+
if !lived_stmts.is_empty() {
192+
// We can ignore the paired StorageLive and StorageDead.
193+
let mut storage_live_locals: DenseBitSet<Local> =
194+
DenseBitSet::new_empty(body.local_decls.len());
195+
for stmt_index in lived_stmts.iter() {
196+
let statement = &statements[stmt_index];
197+
match &statement.kind {
198+
StatementKind::StorageLive(local) if storage_live_locals.insert(*local) => {}
199+
StatementKind::StorageDead(local) if storage_live_locals.remove(*local) => {}
200+
_ => return None,
201+
}
202+
}
203+
if !storage_live_locals.is_empty() {
204+
return None;
205+
}
206+
}
207+
expected_assign
208+
};
209+
let &(dest_place, ref rvalue) = statements[assign_stmt].kind.as_assign()?;
210+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
211+
if dest_ty.ty != src_ty.ty {
212+
return None;
213+
}
214+
let ty::Adt(def, _) = dest_ty.ty.kind() else {
215+
return None;
216+
};
217+
match rvalue {
218+
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
219+
Rvalue::Use(Operand::Constant(box constant))
220+
if let Const::Val(const_, ty) = constant.const_ =>
221+
{
222+
let (ecx, op) =
223+
mk_eval_cx_for_const_val(tcx.at(constant.span), typing_env, const_, ty)?;
224+
let variant = ecx.read_discriminant(&op).discard_err()?;
225+
if !def.variants()[variant].fields.is_empty() {
226+
return None;
227+
}
228+
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
229+
if val != index {
230+
return None;
231+
}
232+
}
233+
Rvalue::Use(Operand::Copy(place)) if *place == src_place => {}
234+
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
235+
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
236+
if fields.is_empty()
237+
&& let Some(Discr { val, .. }) =
238+
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
239+
&& val == index => {}
240+
_ => return None,
241+
}
242+
Some(dest_place)
243+
}

0 commit comments

Comments
 (0)