Skip to content

Commit 9179e1b

Browse files
committed
Add tag_for_variant query
This query allows for sharing code between `rustc_const_eval` and `rustc_transmutability`.
1 parent 9023f90 commit 9179e1b

File tree

7 files changed

+162
-94
lines changed

7 files changed

+162
-94
lines changed

Diff for: compiler/rustc_const_eval/src/const_eval/eval_queries.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use either::{Left, Right};
33
use rustc_hir::def::DefKind;
44
use rustc_middle::mir::interpret::{AllocId, ErrorHandled, InterpErrorInfo};
55
use rustc_middle::mir::{self, ConstAlloc, ConstValue};
6-
use rustc_middle::query::TyCtxtAt;
6+
use rustc_middle::query::{Key, TyCtxtAt};
77
use rustc_middle::traits::Reveal;
88
use rustc_middle::ty::layout::LayoutOf;
99
use rustc_middle::ty::print::with_no_trimmed_paths;
@@ -243,6 +243,24 @@ pub(crate) fn turn_into_const_value<'tcx>(
243243
op_to_const(&ecx, &mplace.into(), /* for diagnostics */ false)
244244
}
245245

246+
/// Computes the tag (if any) for a given type and variant.
247+
#[instrument(skip(tcx), level = "debug")]
248+
pub fn tag_for_variant_provider<'tcx>(
249+
tcx: TyCtxt<'tcx>,
250+
(ty, variant_index): (Ty<'tcx>, abi::VariantIdx),
251+
) -> Option<ty::ScalarInt> {
252+
assert!(ty.is_enum());
253+
254+
let ecx = InterpCx::new(
255+
tcx,
256+
ty.default_span(tcx),
257+
ty::ParamEnv::reveal_all(),
258+
CompileTimeInterpreter::new(CanAccessMutGlobal::No, CheckAlignment::Error),
259+
);
260+
261+
ecx.tag_for_variant(ty, variant_index).unwrap().value()
262+
}
263+
246264
#[instrument(skip(tcx), level = "debug")]
247265
pub fn eval_to_const_value_raw_provider<'tcx>(
248266
tcx: TyCtxt<'tcx>,

Diff for: compiler/rustc_const_eval/src/interpret/discriminant.rs

+106-68
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,28 @@
22
33
use rustc_middle::mir;
44
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
5-
use rustc_middle::ty::{self, Ty};
5+
use rustc_middle::ty::{self, ScalarInt, Ty};
66
use rustc_target::abi::{self, TagEncoding};
77
use rustc_target::abi::{VariantIdx, Variants};
88

99
use super::{ImmTy, InterpCx, InterpResult, Machine, Readable, Scalar, Writeable};
1010

11+
/// The tag of an enum discriminant.
12+
pub(crate) enum Tag {
13+
/// No tag; the variant is `Single`-encoded.
14+
None,
15+
/// The variant is tagged.
16+
Tagged { tag: ScalarInt, tag_field: usize },
17+
/// No tag; the variant is identified by its validity.
18+
Untagged,
19+
}
20+
21+
impl Tag {
22+
pub(crate) fn value(self) -> Option<ScalarInt> {
23+
if let Self::Tagged { tag, .. } = self { Some(tag) } else { None }
24+
}
25+
}
26+
1127
impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
1228
/// Writes the discriminant of the given variant.
1329
///
@@ -28,78 +44,28 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
2844
throw_ub!(UninhabitedEnumVariantWritten(variant_index))
2945
}
3046

31-
match dest.layout().variants {
32-
abi::Variants::Single { index } => {
33-
assert_eq!(index, variant_index);
34-
}
35-
abi::Variants::Multiple {
36-
tag_encoding: TagEncoding::Direct,
37-
tag: tag_layout,
38-
tag_field,
39-
..
40-
} => {
47+
let (tag, tag_field) = match self.tag_for_variant(dest.layout().ty, variant_index)? {
48+
Tag::None => return Ok(()),
49+
Tag::Tagged { tag, tag_field } => {
4150
// No need to validate that the discriminant here because the
42-
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
43-
44-
let discr_val = dest
45-
.layout()
46-
.ty
47-
.discriminant_for_variant(*self.tcx, variant_index)
48-
.unwrap()
49-
.val;
50-
51-
// raw discriminants for enums are isize or bigger during
52-
// their computation, but the in-memory tag is the smallest possible
53-
// representation
54-
let size = tag_layout.size(self);
55-
let tag_val = size.truncate(discr_val);
56-
57-
let tag_dest = self.project_field(dest, tag_field)?;
58-
self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
51+
// `TyAndLayout::for_variant()` call earlier already checks the
52+
// variant is valid.
53+
(tag, tag_field)
5954
}
60-
abi::Variants::Multiple {
61-
tag_encoding:
62-
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
63-
tag: tag_layout,
64-
tag_field,
65-
..
66-
} => {
67-
// No need to validate that the discriminant here because the
68-
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
69-
70-
if variant_index != untagged_variant {
71-
let variants_start = niche_variants.start().as_u32();
72-
let variant_index_relative = variant_index
73-
.as_u32()
74-
.checked_sub(variants_start)
75-
.expect("overflow computing relative variant idx");
76-
// We need to use machine arithmetic when taking into account `niche_start`:
77-
// tag_val = variant_index_relative + niche_start_val
78-
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
79-
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
80-
let variant_index_relative_val =
81-
ImmTy::from_uint(variant_index_relative, tag_layout);
82-
let tag_val = self.wrapping_binary_op(
83-
mir::BinOp::Add,
84-
&variant_index_relative_val,
85-
&niche_start_val,
86-
)?;
87-
// Write result.
88-
let niche_dest = self.project_field(dest, tag_field)?;
89-
self.write_immediate(*tag_val, &niche_dest)?;
90-
} else {
91-
// The untagged variant is implicitly encoded simply by having a value that is
92-
// outside the niche variants. But what if the data stored here does not
93-
// actually encode this variant? That would be bad! So let's double-check...
94-
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
95-
if actual_variant != variant_index {
96-
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
97-
}
55+
Tag::Untagged => {
56+
// The untagged variant is implicitly encoded simply by having a value that is
57+
// outside the niche variants. But what if the data stored here does not
58+
// actually encode this variant? That would be bad! So let's double-check...
59+
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
60+
if actual_variant != variant_index {
61+
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
9862
}
63+
return Ok(());
9964
}
100-
}
65+
};
10166

102-
Ok(())
67+
let tag_dest = self.project_field(dest, tag_field)?;
68+
self.write_scalar(tag, &tag_dest)
10369
}
10470

10571
/// Read discriminant, return the runtime value as well as the variant index.
@@ -277,4 +243,76 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
277243
};
278244
Ok(ImmTy::from_scalar(discr_value, discr_layout))
279245
}
246+
247+
/// Computes the tag (if any) of a given variant of type `ty`.
248+
pub(crate) fn tag_for_variant(
249+
&self,
250+
ty: Ty<'tcx>,
251+
variant_index: VariantIdx,
252+
) -> InterpResult<'tcx, Tag> {
253+
match self.layout_of(ty)?.variants {
254+
abi::Variants::Single { index } => {
255+
assert_eq!(index, variant_index);
256+
Ok(Tag::None)
257+
}
258+
259+
abi::Variants::Multiple {
260+
tag_encoding: TagEncoding::Direct,
261+
tag: tag_layout,
262+
tag_field,
263+
..
264+
} => {
265+
// raw discriminants for enums are isize or bigger during
266+
// their computation, but the in-memory tag is the smallest possible
267+
// representation
268+
let discr = self.discriminant_for_variant(ty, variant_index)?;
269+
let discr_size = discr.layout.size;
270+
let discr_val = discr.to_scalar().to_bits(discr_size)?;
271+
let tag_size = tag_layout.size(self);
272+
let tag_val = tag_size.truncate(discr_val);
273+
let tag = ScalarInt::try_from_uint(tag_val, tag_size).unwrap();
274+
Ok(Tag::Tagged { tag, tag_field })
275+
}
276+
277+
abi::Variants::Multiple {
278+
tag_encoding: TagEncoding::Niche { untagged_variant, .. },
279+
..
280+
} if untagged_variant == variant_index => {
281+
// The untagged variant is implicitly encoded simply by having a
282+
// value that is outside the niche variants.
283+
Ok(Tag::Untagged)
284+
}
285+
286+
abi::Variants::Multiple {
287+
tag_encoding:
288+
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
289+
tag: tag_layout,
290+
tag_field,
291+
..
292+
} => {
293+
assert!(variant_index != untagged_variant);
294+
let variants_start = niche_variants.start().as_u32();
295+
let variant_index_relative = variant_index
296+
.as_u32()
297+
.checked_sub(variants_start)
298+
.expect("overflow computing relative variant idx");
299+
// We need to use machine arithmetic when taking into account `niche_start`:
300+
// tag_val = variant_index_relative + niche_start_val
301+
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
302+
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
303+
let variant_index_relative_val =
304+
ImmTy::from_uint(variant_index_relative, tag_layout);
305+
let tag = self
306+
.wrapping_binary_op(
307+
mir::BinOp::Add,
308+
&variant_index_relative_val,
309+
&niche_start_val,
310+
)?
311+
.to_scalar()
312+
.try_to_int()
313+
.unwrap();
314+
Ok(Tag::Tagged { tag, tag_field })
315+
}
316+
}
317+
}
280318
}

Diff for: compiler/rustc_const_eval/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
4040

4141
pub fn provide(providers: &mut Providers) {
4242
const_eval::provide(providers);
43+
providers.tag_for_variant = const_eval::tag_for_variant_provider;
4344
providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
4445
providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
4546
providers.eval_static_initializer = const_eval::eval_static_initializer_provider;

Diff for: compiler/rustc_middle/src/query/erase.rs

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ trivial! {
234234
Option<rustc_middle::middle::stability::DeprecationEntry>,
235235
Option<rustc_middle::ty::Destructor>,
236236
Option<rustc_middle::ty::ImplTraitInTraitData>,
237+
Option<rustc_middle::ty::ScalarInt>,
237238
Option<rustc_span::def_id::CrateNum>,
238239
Option<rustc_span::def_id::DefId>,
239240
Option<rustc_span::def_id::LocalDefId>,

Diff for: compiler/rustc_middle/src/query/keys.rs

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_query_system::query::DefIdCacheSelector;
1313
use rustc_query_system::query::{DefaultCacheSelector, SingleCacheSelector, VecCacheSelector};
1414
use rustc_span::symbol::{Ident, Symbol};
1515
use rustc_span::{Span, DUMMY_SP};
16+
use rustc_target::abi;
1617

1718
/// Placeholder for `CrateNum`'s "local" counterpart
1819
#[derive(Copy, Clone, Debug)]
@@ -502,6 +503,14 @@ impl<'tcx> Key for (DefId, Ty<'tcx>, GenericArgsRef<'tcx>, ty::ParamEnv<'tcx>) {
502503
}
503504
}
504505

506+
impl<'tcx> Key for (Ty<'tcx>, abi::VariantIdx) {
507+
type CacheSelector = DefaultCacheSelector<Self>;
508+
509+
fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
510+
DUMMY_SP
511+
}
512+
}
513+
505514
impl<'tcx> Key for (ty::Predicate<'tcx>, traits::WellFormedLoc) {
506515
type CacheSelector = DefaultCacheSelector<Self>;
507516

Diff for: compiler/rustc_middle/src/query/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,13 @@ rustc_queries! {
10421042
}
10431043
}
10441044

1045+
/// Computes the tag (if any) for a given type and variant.
1046+
query tag_for_variant(
1047+
key: (Ty<'tcx>, abi::VariantIdx)
1048+
) -> Option<ty::ScalarInt> {
1049+
desc { "computing variant tag for enum" }
1050+
}
1051+
10451052
/// Evaluates a constant and returns the computed allocation.
10461053
///
10471054
/// **Do not use this** directly, use the `eval_to_const_value` or `eval_to_valtree` instead.

Diff for: compiler/rustc_transmute/src/layout/tree.rs

+19-25
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ pub(crate) mod rustc {
174174
use crate::layout::rustc::{Def, Ref};
175175

176176
use rustc_middle::ty::layout::LayoutError;
177-
use rustc_middle::ty::util::Discr;
178177
use rustc_middle::ty::AdtDef;
179178
use rustc_middle::ty::GenericArgsRef;
180179
use rustc_middle::ty::ParamEnv;
180+
use rustc_middle::ty::ScalarInt;
181181
use rustc_middle::ty::VariantDef;
182182
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
183183
use rustc_span::ErrorGuaranteed;
@@ -331,14 +331,15 @@ pub(crate) mod rustc {
331331
trace!(?adt_def, "treeifying enum");
332332
let mut tree = Tree::uninhabited();
333333

334-
for (idx, discr) in adt_def.discriminants(tcx) {
334+
for (idx, variant) in adt_def.variants().iter_enumerated() {
335+
let tag = tcx.tag_for_variant((ty, idx));
335336
tree = tree.or(Self::from_repr_c_variant(
336337
ty,
337338
*adt_def,
338339
args_ref,
339340
&layout_summary,
340-
Some(discr),
341-
adt_def.variant(idx),
341+
tag,
342+
variant,
342343
tcx,
343344
)?);
344345
}
@@ -393,7 +394,7 @@ pub(crate) mod rustc {
393394
adt_def: AdtDef<'tcx>,
394395
args_ref: GenericArgsRef<'tcx>,
395396
layout_summary: &LayoutSummary,
396-
discr: Option<Discr<'tcx>>,
397+
tag: Option<ScalarInt>,
397398
variant_def: &'tcx VariantDef,
398399
tcx: TyCtxt<'tcx>,
399400
) -> Result<Self, Err> {
@@ -403,9 +404,6 @@ pub(crate) mod rustc {
403404
let min_align = repr.align.unwrap_or(Align::ONE);
404405
let max_align = repr.pack.unwrap_or(Align::MAX);
405406

406-
let clamp =
407-
|align: Align| align.clamp(min_align, max_align).bytes().try_into().unwrap();
408-
409407
let variant_span = trace_span!(
410408
"treeifying variant",
411409
min_align = ?min_align,
@@ -419,17 +417,12 @@ pub(crate) mod rustc {
419417
)
420418
.unwrap();
421419

422-
// The layout of the variant is prefixed by the discriminant, if any.
423-
if let Some(discr) = discr {
424-
trace!(?discr, "treeifying discriminant");
425-
let discr_layout = alloc::Layout::from_size_align(
426-
layout_summary.discriminant_size,
427-
clamp(layout_summary.discriminant_align),
428-
)
429-
.unwrap();
430-
trace!(?discr_layout, "computed discriminant layout");
431-
variant_layout = variant_layout.extend(discr_layout).unwrap().0;
432-
tree = tree.then(Self::from_discr(discr, tcx, layout_summary.discriminant_size));
420+
// The layout of the variant is prefixed by the tag, if any.
421+
if let Some(tag) = tag {
422+
let tag_layout =
423+
alloc::Layout::from_size_align(tag.size().bytes_usize(), 1).unwrap();
424+
tree = tree.then(Self::from_tag(tag, tcx));
425+
variant_layout = variant_layout.extend(tag_layout).unwrap().0;
433426
}
434427

435428
// Next come fields.
@@ -469,18 +462,19 @@ pub(crate) mod rustc {
469462
Ok(tree)
470463
}
471464

472-
pub fn from_discr(discr: Discr<'tcx>, tcx: TyCtxt<'tcx>, size: usize) -> Self {
465+
pub fn from_tag(tag: ScalarInt, tcx: TyCtxt<'tcx>) -> Self {
473466
use rustc_target::abi::Endian;
474-
467+
let size = tag.size();
468+
let bits = tag.to_bits(size).unwrap();
475469
let bytes: [u8; 16];
476470
let bytes = match tcx.data_layout.endian {
477471
Endian::Little => {
478-
bytes = discr.val.to_le_bytes();
479-
&bytes[..size]
472+
bytes = bits.to_le_bytes();
473+
&bytes[..size.bytes_usize()]
480474
}
481475
Endian::Big => {
482-
bytes = discr.val.to_be_bytes();
483-
&bytes[bytes.len() - size..]
476+
bytes = bits.to_be_bytes();
477+
&bytes[bytes.len() - size.bytes_usize()..]
484478
}
485479
};
486480
Self::Seq(bytes.iter().map(|&b| Self::from_bits(b)).collect())

0 commit comments

Comments
 (0)