Skip to content

Commit 5a8dfd9

Browse files
committed
Auto merge of #85158 - JulianKnodt:array_const_val, r=cjgillot
Mir-Opt for copying enums with large discrepancies. I have been meaning to make this for quite a while, based off of this [hackmd](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD). I'm not sure where to put this opt now that I've made it, so I'd appreciate suggestions on that! It's also one long chain of statements; I'm not sure if there's a more friendly format for it. r? `@tmiasko` I would `r` oli but he's on leave so he suggested I `r` tmiasko or wesleywiser.
2 parents 2773383 + 15d4728 commit 5a8dfd9

11 files changed

+789
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
use crate::rustc_middle::ty::util::IntTypeExt;
2+
use crate::MirPass;
3+
use rustc_data_structures::fx::FxHashMap;
4+
use rustc_middle::mir::interpret::AllocId;
5+
use rustc_middle::mir::*;
6+
use rustc_middle::ty::{self, AdtDef, Const, ParamEnv, Ty, TyCtxt};
7+
use rustc_session::Session;
8+
use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
9+
10+
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
11+
/// enough discrepancy between the sizes of their variants.
12+
///
13+
/// For example, given an enum with two variants:
14+
/// ```
15+
/// enum Example {
16+
/// Small,
17+
/// Large([u32; 1024]),
18+
/// }
19+
/// ```
20+
/// instead of emitting a move of the whole enum (sized for the large variant),
21+
/// perform a memcpy of only the active variant.
22+
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
23+
///
24+
/// In summary, what this does is determine at runtime which enum variant is active,
25+
/// and instead of copying all the bytes of the largest possible variant,
26+
/// copy only the bytes for the currently active variant.
27+
pub struct EnumSizeOpt {
28+
/// Minimum difference, in bytes, between the largest and the smallest variant layout
/// for an enum to be considered by this pass.
pub(crate) discrepancy: u64,
29+
}
30+
31+
impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
32+
/// The pass runs only when `unsound_mir_opts` is set or the MIR optimization level is
/// at least 3.
fn is_enabled(&self, sess: &Session) -> bool {
33+
sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
34+
}
35+
/// Entry point of the pass: delegates all of the work to `EnumSizeOpt::optim`.
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
36+
// NOTE: This pass may produce different MIR based on the alignment of the target
37+
// platform, but it will still be valid.
38+
self.optim(tcx, body);
39+
}
40+
}
41+
42+
impl EnumSizeOpt {
43+
fn candidate<'tcx>(
44+
&self,
45+
tcx: TyCtxt<'tcx>,
46+
param_env: ParamEnv<'tcx>,
47+
ty: Ty<'tcx>,
48+
alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
49+
) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
50+
let adt_def = match ty.kind() {
51+
ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
52+
_ => return None,
53+
};
54+
let layout = tcx.layout_of(param_env.and(ty)).ok()?;
55+
let variants = match &layout.variants {
56+
Variants::Single { .. } => return None,
57+
Variants::Multiple { tag_encoding, .. }
58+
if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
59+
{
60+
return None;
61+
}
62+
Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
63+
Variants::Multiple { variants, .. } => variants,
64+
};
65+
let min = variants.iter().map(|v| v.size).min().unwrap();
66+
let max = variants.iter().map(|v| v.size).max().unwrap();
67+
if max.bytes() - min.bytes() < self.discrepancy {
68+
return None;
69+
}
70+
71+
let num_discrs = adt_def.discriminants(tcx).count();
72+
if variants.iter_enumerated().any(|(var_idx, _)| {
73+
let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
74+
(discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
75+
}) {
76+
return None;
77+
}
78+
if let Some(alloc_id) = alloc_cache.get(&ty) {
79+
return Some((*adt_def, num_discrs, *alloc_id));
80+
}
81+
82+
let data_layout = tcx.data_layout();
83+
let ptr_sized_int = data_layout.ptr_sized_integer();
84+
let target_bytes = ptr_sized_int.size().bytes() as usize;
85+
let mut data = vec![0; target_bytes * num_discrs];
86+
macro_rules! encode_store {
87+
($curr_idx: expr, $endian: expr, $bytes: expr) => {
88+
let bytes = match $endian {
89+
rustc_target::abi::Endian::Little => $bytes.to_le_bytes(),
90+
rustc_target::abi::Endian::Big => $bytes.to_be_bytes(),
91+
};
92+
for (i, b) in bytes.into_iter().enumerate() {
93+
data[$curr_idx + i] = b;
94+
}
95+
};
96+
}
97+
98+
for (var_idx, layout) in variants.iter_enumerated() {
99+
let curr_idx =
100+
target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
101+
let sz = layout.size;
102+
match ptr_sized_int {
103+
rustc_target::abi::Integer::I32 => {
104+
encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
105+
}
106+
rustc_target::abi::Integer::I64 => {
107+
encode_store!(curr_idx, data_layout.endian, sz.bytes());
108+
}
109+
_ => unreachable!(),
110+
};
111+
}
112+
let alloc = interpret::Allocation::from_bytes(
113+
data,
114+
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
115+
Mutability::Not,
116+
);
117+
let alloc = tcx.create_memory_alloc(tcx.intern_const_alloc(alloc));
118+
Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
119+
}
120+
/// Rewrites each `lhs = copy rhs` / `lhs = move rhs` assignment whose destination type is
/// accepted by `candidate` into a statement sequence that reads the discriminant of `rhs`,
/// looks up an element count in a constant size table, and copies that many `u8`s from
/// `rhs` to `lhs` with `copy_nonoverlapping`.
fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
121+
// Size-table allocations created so far, keyed by enum type (filled in by `candidate`).
let mut alloc_cache = FxHashMap::default();
122+
let body_did = body.source.def_id();
123+
let param_env = tcx.param_env(body_did);
124+
125+
let blocks = body.basic_blocks.as_mut();
126+
let local_decls = &mut body.local_decls;
127+
128+
for bb in blocks {
129+
bb.expand_statements(|st| {
130+
// Only plain `lhs = copy/move rhs` assignments are considered.
if let StatementKind::Assign(box (
131+
lhs,
132+
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
133+
)) = &st.kind
134+
{
135+
let ty = lhs.ty(local_decls, tcx).ty;
136+
137+
// All generated statements and locals reuse the source info / span of the
// assignment being replaced.
let source_info = st.source_info;
138+
let span = source_info.span;
139+
140+
// `?` leaves the closure with `None` when `ty` is not a candidate.
// NOTE(review): presumably `expand_statements` then keeps the statement as is.
let (adt_def, num_variants, alloc_id) =
141+
self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
142+
let alloc = tcx.global_alloc(alloc_id).unwrap_memory();
143+
144+
// `[usize; num_variants]`: the type of the local that receives the size table.
let tmp_ty = tcx.mk_ty(ty::Array(
145+
tcx.types.usize,
146+
Const::from_usize(tcx, num_variants as u64),
147+
));
148+
149+
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
150+
let store_live = Statement {
151+
source_info,
152+
kind: StatementKind::StorageLive(size_array_local),
153+
};
154+
155+
// `size_array_local = const <size table>`
let place = Place::from(size_array_local);
156+
let constant_vals = Constant {
157+
span,
158+
user_ty: None,
159+
literal: ConstantKind::Val(
160+
interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
161+
tmp_ty,
162+
),
163+
};
164+
let rval = Rvalue::Use(Operand::Constant(box (constant_vals)));
165+
166+
let const_assign =
167+
Statement { source_info, kind: StatementKind::Assign(box (place, rval)) };
168+
169+
// `discr_place = discriminant(rhs)`, stored in a fresh local of the enum's
// discriminant type.
let discr_place = Place::from(
170+
local_decls
171+
.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
172+
);
173+
174+
let store_discr = Statement {
175+
source_info,
176+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))),
177+
};
178+
179+
// `discr_cast_place = discr_place as usize`, so it can be used as an array index.
let discr_cast_place =
180+
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
181+
182+
let cast_discr = Statement {
183+
source_info,
184+
kind: StatementKind::Assign(box (
185+
discr_cast_place,
186+
Rvalue::Cast(
187+
CastKind::IntToInt,
188+
Operand::Copy(discr_place),
189+
tcx.types.usize,
190+
),
191+
)),
192+
};
193+
194+
// `size_place = size_array_local[discr_cast_place]`
let size_place =
195+
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
196+
197+
let store_size = Statement {
198+
source_info,
199+
kind: StatementKind::Assign(box (
200+
size_place,
201+
Rvalue::Use(Operand::Copy(Place {
202+
local: size_array_local,
203+
projection: tcx.intern_place_elems(&[PlaceElem::Index(
204+
discr_cast_place.local,
205+
)]),
206+
})),
207+
)),
208+
};
209+
210+
// `dst = &raw mut lhs`
let dst =
211+
Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));
212+
213+
let dst_ptr = Statement {
214+
source_info,
215+
kind: StatementKind::Assign(box (
216+
dst,
217+
Rvalue::AddressOf(Mutability::Mut, *lhs),
218+
)),
219+
};
220+
221+
// `dst_cast_place = dst as *mut u8`
let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
222+
let dst_cast_place =
223+
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
224+
225+
let dst_cast = Statement {
226+
source_info,
227+
kind: StatementKind::Assign(box (
228+
dst_cast_place,
229+
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
230+
)),
231+
};
232+
233+
// `src = &raw const rhs`
let src =
234+
Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));
235+
236+
let src_ptr = Statement {
237+
source_info,
238+
kind: StatementKind::Assign(box (
239+
src,
240+
Rvalue::AddressOf(Mutability::Not, *rhs),
241+
)),
242+
};
243+
244+
// `src_cast_place = src as *const u8`
let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
245+
let src_cast_place =
246+
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
247+
248+
let src_cast = Statement {
249+
source_info,
250+
kind: StatementKind::Assign(box (
251+
src_cast_place,
252+
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
253+
)),
254+
};
255+
256+
// NOTE(review): this deinitializes `dst`, the temporary that holds the pointer
// to `lhs`, not `lhs` itself -- confirm that this is what is intended.
let deinit_old =
257+
Statement { source_info, kind: StatementKind::Deinit(box dst) };
258+
259+
// Copy `size_place` elements between the two `u8` pointers.
let copy_bytes = Statement {
260+
source_info,
261+
kind: StatementKind::Intrinsic(
262+
box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
263+
src: Operand::Copy(src_cast_place),
264+
dst: Operand::Copy(dst_cast_place),
265+
count: Operand::Copy(size_place),
266+
}),
267+
),
268+
};
269+
270+
// Only the size-table local gets StorageLive/StorageDead markers.
let store_dead = Statement {
271+
source_info,
272+
kind: StatementKind::StorageDead(size_array_local),
273+
};
274+
// The statements are listed in the order in which they must execute.
let iter = [
275+
store_live,
276+
const_assign,
277+
store_discr,
278+
cast_discr,
279+
store_size,
280+
dst_ptr,
281+
dst_cast,
282+
src_ptr,
283+
src_cast,
284+
deinit_old,
285+
copy_bytes,
286+
store_dead,
287+
]
288+
.into_iter();
289+
290+
// The original assignment is turned into a nop; the statements above take its place.
st.make_nop();
291+
Some(iter)
292+
} else {
293+
None
294+
}
295+
});
296+
}
297+
}
298+
}

compiler/rustc_mir_transform/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![allow(rustc::potential_query_instability)]
22
#![feature(box_patterns)]
33
#![feature(drain_filter)]
4+
#![feature(box_syntax)]
45
#![feature(let_chains)]
56
#![feature(map_try_insert)]
67
#![feature(min_specialization)]
@@ -73,6 +74,7 @@ mod function_item_references;
7374
mod generator;
7475
mod inline;
7576
mod instcombine;
77+
mod large_enums;
7678
mod lower_intrinsics;
7779
mod lower_slice_len;
7880
mod match_branches;
@@ -583,6 +585,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
583585
&simplify::SimplifyLocals::new("final"),
584586
&multiple_return_terminators::MultipleReturnTerminators,
585587
&deduplicate_blocks::DeduplicateBlocks,
588+
&large_enums::EnumSizeOpt { discrepancy: 128 },
586589
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
587590
&add_call_guards::CriticalCallEdges,
588591
// Dump the end result for testing and debugging purposes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
- // MIR for `cand` before EnumSizeOpt
2+
+ // MIR for `cand` after EnumSizeOpt
3+
4+
fn cand() -> Candidate {
5+
let mut _0: Candidate; // return place in scope 0 at $DIR/enum_opt.rs:+0:18: +0:27
6+
let mut _1: Candidate; // in scope 0 at $DIR/enum_opt.rs:+1:7: +1:12
7+
let mut _2: Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:7: +2:34
8+
let mut _3: [u8; 8196]; // in scope 0 at $DIR/enum_opt.rs:+2:24: +2:33
9+
+ let mut _4: [usize; 2]; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
10+
+ let mut _5: isize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
11+
+ let mut _6: usize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
12+
+ let mut _7: usize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
13+
+ let mut _8: *mut Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
14+
+ let mut _9: *mut u8; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
15+
+ let mut _10: *const Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
16+
+ let mut _11: *const u8; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
17+
+ let mut _12: [usize; 2]; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
18+
+ let mut _13: isize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
19+
+ let mut _14: usize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
20+
+ let mut _15: usize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
21+
+ let mut _16: *mut Candidate; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
22+
+ let mut _17: *mut u8; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
23+
+ let mut _18: *const Candidate; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
24+
+ let mut _19: *const u8; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
25+
scope 1 {
26+
debug a => _1; // in scope 1 at $DIR/enum_opt.rs:+1:7: +1:12
27+
}
28+
29+
bb0: {
30+
StorageLive(_1); // scope 0 at $DIR/enum_opt.rs:+1:7: +1:12
31+
_1 = Candidate::Small(const 1_u8); // scope 0 at $DIR/enum_opt.rs:+1:15: +1:34
32+
StorageLive(_2); // scope 1 at $DIR/enum_opt.rs:+2:7: +2:34
33+
StorageLive(_3); // scope 1 at $DIR/enum_opt.rs:+2:24: +2:33
34+
_3 = [const 1_u8; 8196]; // scope 1 at $DIR/enum_opt.rs:+2:24: +2:33
35+
_2 = Candidate::Large(move _3); // scope 1 at $DIR/enum_opt.rs:+2:7: +2:34
36+
StorageDead(_3); // scope 1 at $DIR/enum_opt.rs:+2:33: +2:34
37+
- _1 = move _2; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
38+
+ StorageLive(_4); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
39+
+ _4 = const [2_usize, 8197_usize]; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
40+
+ _5 = discriminant(_2); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
41+
+ _6 = _5 as usize (IntToInt); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
42+
+ _7 = _4[_6]; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
43+
+ _8 = &raw mut _1; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
44+
+ _9 = _8 as *mut u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
45+
+ _10 = &raw const _2; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
46+
+ _11 = _10 as *const u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
47+
+ Deinit(_8); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
48+
+ copy_nonoverlapping(dst = _9, src = _11, count = _7); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
49+
+ StorageDead(_4); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
50+
StorageDead(_2); // scope 1 at $DIR/enum_opt.rs:+2:33: +2:34
51+
- _0 = move _1; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
52+
+ StorageLive(_12); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
53+
+ _12 = const [2_usize, 8197_usize]; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
54+
+ _13 = discriminant(_1); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
55+
+ _14 = _13 as usize (IntToInt); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
56+
+ _15 = _12[_14]; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
57+
+ _16 = &raw mut _0; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
58+
+ _17 = _16 as *mut u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
59+
+ _18 = &raw const _1; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
60+
+ _19 = _18 as *const u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
61+
+ Deinit(_16); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
62+
+ copy_nonoverlapping(dst = _17, src = _19, count = _15); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
63+
+ StorageDead(_12); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
64+
StorageDead(_1); // scope 0 at $DIR/enum_opt.rs:+4:1: +4:2
65+
return; // scope 0 at $DIR/enum_opt.rs:+4:2: +4:2
66+
}
67+
}
68+

0 commit comments

Comments
 (0)