Skip to content

Simplify thir::PatKind::ExpandedConstant #139108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -812,23 +812,17 @@ pub enum PatKind<'tcx> {
},

/// Pattern obtained by converting a constant (inline or named) to its pattern
/// representation using `const_to_pat`.
/// representation using `const_to_pat`. This is used for unsafety checking.
ExpandedConstant {
/// [DefId] of the constant, we need this so that we have a
/// reference that can be used by unsafety checking to visit nested
/// unevaluated constants and for diagnostics. If the `DefId` doesn't
/// correspond to a local crate, it points at the `const` item.
/// [DefId] of the constant item.
def_id: DefId,
/// If `false`, then `def_id` points at a `const` item, otherwise it
/// corresponds to a local inline const.
is_inline: bool,
/// If the inline constant is used in a range pattern, this subpattern
/// represents the range (if both ends are inline constants, there will
/// be multiple InlineConstant wrappers).
/// The pattern that the constant lowered to.
///
/// Otherwise, the actual pattern that the constant lowered to. As with
/// other constants, inline constants are matched structurally where
/// possible.
/// HACK: we need to keep the `DefId` of inline constants around for unsafety checking;
/// therefore when a range pattern contains inline constants, we re-wrap the range pattern
/// with the `ExpandedConstant` nodes that correspond to the range endpoints. Hence
/// `subpattern` may actually be a range pattern, and `def_id` be the constant for one of
/// its endpoints.
subpattern: Box<Pat<'tcx>>,
},

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl<'a, 'tcx> ParseCtxt<'a, 'tcx> {
let arm = &self.thir[*arm];
let value = match arm.pattern.kind {
PatKind::Constant { value } => value,
PatKind::ExpandedConstant { ref subpattern, def_id: _, is_inline: false }
PatKind::ExpandedConstant { ref subpattern, def_id: _ }
if let PatKind::Constant { value } = subpattern.kind =>
{
value
Expand Down
31 changes: 1 addition & 30 deletions compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,39 +201,10 @@ impl<'tcx> MatchPairTree<'tcx> {
None
}

PatKind::ExpandedConstant { subpattern: ref pattern, def_id: _, is_inline: false } => {
PatKind::ExpandedConstant { subpattern: ref pattern, .. } => {
MatchPairTree::for_pattern(place_builder, pattern, cx, &mut subpairs, extra_data);
None
}
PatKind::ExpandedConstant { subpattern: ref pattern, def_id, is_inline: true } => {
MatchPairTree::for_pattern(place_builder, pattern, cx, &mut subpairs, extra_data);

// Apply a type ascription for the inline constant to the value at `match_pair.place`
if let Some(source) = place {
let span = pattern.span;
let parent_id = cx.tcx.typeck_root_def_id(cx.def_id.to_def_id());
let args = ty::InlineConstArgs::new(
cx.tcx,
ty::InlineConstArgsParts {
parent_args: ty::GenericArgs::identity_for_item(cx.tcx, parent_id),
ty: cx.infcx.next_ty_var(span),
},
)
.args;
let user_ty = cx.infcx.canonicalize_user_type_annotation(ty::UserType::new(
ty::UserTypeKind::TypeOf(def_id, ty::UserArgs { args, user_self_ty: None }),
));
let annotation = ty::CanonicalUserTypeAnnotation {
inferred_ty: pattern.ty,
span,
user_ty: Box::new(user_ty),
};
let variance = ty::Contravariant;
extra_data.ascriptions.push(super::Ascription { annotation, source, variance });
}

None
}

PatKind::Array { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt;
}
PatKind::ExpandedConstant { def_id, is_inline, .. } => {
PatKind::ExpandedConstant { def_id, .. } => {
if let Some(def) = def_id.as_local()
&& *is_inline
&& matches!(self.tcx.def_kind(def_id), DefKind::InlineConst)
{
self.visit_inner_body(def);
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_mir_build/src/thir/pattern/check_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
unpeeled_pat = subpattern;
}

if let PatKind::ExpandedConstant { def_id, is_inline: false, .. } = unpeeled_pat.kind
if let PatKind::ExpandedConstant { def_id, .. } = unpeeled_pat.kind
&& let DefKind::Const = self.tcx.def_kind(def_id)
&& let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(pat.span)
// We filter out paths with multiple path::segments.
Expand Down Expand Up @@ -1296,7 +1296,8 @@ fn report_non_exhaustive_match<'p, 'tcx>(

for &arm in arms {
let arm = &thir.arms[arm];
if let PatKind::ExpandedConstant { def_id, is_inline: false, .. } = arm.pattern.kind
if let PatKind::ExpandedConstant { def_id, .. } = arm.pattern.kind
&& !matches!(cx.tcx.def_kind(def_id), DefKind::InlineConst)
&& let Ok(snippet) = cx.tcx.sess.source_map().span_to_snippet(arm.pattern.span)
// We filter out paths with multiple path::segments.
&& snippet.chars().all(|c| c.is_alphanumeric() || c == '_')
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ impl<'tcx> ConstToPat<'tcx> {
}
}

inlined_const_as_pat
// Wrap the pattern in a marker node to indicate that it is the result of lowering a
// constant. This is used for diagnostics, and for unsafety checking of inline const blocks.
let kind = PatKind::ExpandedConstant { subpattern: inlined_const_as_pat, def_id: uv.def };
Box::new(Pat { kind, ty, span: self.span })
}

fn field_pats(
Expand Down
147 changes: 75 additions & 72 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
use rustc_hir::{self as hir, LangItem, RangeEnd};
use rustc_index::Idx;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::mir::interpret::LitToConstInput;
use rustc_middle::thir::{
Ascription, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary,
};
use rustc_middle::ty::layout::IntegerExt;
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, TyCtxt, TypeVisitableExt};
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, TyCtxt, TypingMode};
use rustc_middle::{bug, span_bug};
use rustc_span::def_id::LocalDefId;
use rustc_span::def_id::DefId;
use rustc_span::{ErrorGuaranteed, Span};
use tracing::{debug, instrument};

Expand Down Expand Up @@ -124,7 +125,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
expr: Option<&'tcx hir::PatExpr<'tcx>>,
// Out-parameters collecting extra data to be reapplied by the caller
ascriptions: &mut Vec<Ascription<'tcx>>,
inline_consts: &mut Vec<LocalDefId>,
expanded_consts: &mut Vec<DefId>,
) -> Result<Option<PatRangeBoundary<'tcx>>, ErrorGuaranteed> {
let Some(expr) = expr else { return Ok(None) };

Expand All @@ -139,10 +140,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
ascriptions.push(ascription);
kind = subpattern.kind;
}
PatKind::ExpandedConstant { is_inline, def_id, subpattern } => {
if is_inline {
inline_consts.extend(def_id.as_local());
}
PatKind::ExpandedConstant { def_id, subpattern } => {
expanded_consts.push(def_id);
kind = subpattern.kind;
}
_ => break,
Expand Down Expand Up @@ -221,10 +220,10 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {

// Collect extra data while lowering the endpoints, to be reapplied later.
let mut ascriptions = vec![];
let mut inline_consts = vec![];
let mut expanded_consts = vec![];

let mut lower_endpoint =
|expr| self.lower_pattern_range_endpoint(expr, &mut ascriptions, &mut inline_consts);
|expr| self.lower_pattern_range_endpoint(expr, &mut ascriptions, &mut expanded_consts);

let lo = lower_endpoint(lo_expr)?.unwrap_or(PatRangeBoundary::NegInfinity);
let hi = lower_endpoint(hi_expr)?.unwrap_or(PatRangeBoundary::PosInfinity);
Expand Down Expand Up @@ -269,17 +268,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// `Foo::<'a>::A..=Foo::B`), we need to put the ascriptions for the associated
// constants somewhere. Have them on the range pattern.
for ascription in ascriptions {
kind = PatKind::AscribeUserType {
ascription,
subpattern: Box::new(Pat { span, ty, kind }),
};
let subpattern = Box::new(Pat { span, ty, kind });
kind = PatKind::AscribeUserType { ascription, subpattern };
}
for def in inline_consts {
kind = PatKind::ExpandedConstant {
def_id: def.to_def_id(),
is_inline: true,
subpattern: Box::new(Pat { span, ty, kind }),
};
for def_id in expanded_consts {
let subpattern = Box::new(Pat { span, ty, kind });
kind = PatKind::ExpandedConstant { def_id, subpattern };
}
Ok(kind)
}
Expand Down Expand Up @@ -567,15 +561,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// Lower the named constant to a THIR pattern.
let args = self.typeck_results.node_args(id);
let c = ty::Const::new_unevaluated(self.tcx, ty::UnevaluatedConst { def: def_id, args });
let subpattern = self.const_to_pat(c, ty, id, span);

// Wrap the pattern in a marker node to indicate that it is the result
// of lowering a named constant. This marker is used for improved
// diagnostics in some situations, but has no effect at runtime.
let mut pattern = {
let kind = PatKind::ExpandedConstant { subpattern, def_id, is_inline: false };
Box::new(Pat { span, ty, kind })
};
let mut pattern = self.const_to_pat(c, ty, id, span);

// If this is an associated constant with an explicit user-written
// type, add an ascription node (e.g. `<Foo<'a> as MyTrait>::CONST`).
Expand Down Expand Up @@ -612,18 +598,37 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
let ty = tcx.typeck(def_id).node_type(block.hir_id);

let typeck_root_def_id = tcx.typeck_root_def_id(def_id.to_def_id());
let parent_args =
tcx.erase_regions(ty::GenericArgs::identity_for_item(tcx, typeck_root_def_id));
let parent_args = ty::GenericArgs::identity_for_item(tcx, typeck_root_def_id);
let args = ty::InlineConstArgs::new(tcx, ty::InlineConstArgsParts { parent_args, ty }).args;

debug_assert!(!args.has_free_regions());

let ct = ty::UnevaluatedConst { def: def_id.to_def_id(), args };
let subpattern = self.const_to_pat(ty::Const::new_unevaluated(self.tcx, ct), ty, id, span);

// Wrap the pattern in a marker node to indicate that it is the result
// of lowering an inline const block.
PatKind::ExpandedConstant { subpattern, def_id: def_id.to_def_id(), is_inline: true }
let c = ty::Const::new_unevaluated(self.tcx, ct);
let pattern = self.const_to_pat(c, ty, id, span);

// Apply a type ascription for the inline constant.
let annotation = {
let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
let args = ty::InlineConstArgs::new(
tcx,
ty::InlineConstArgsParts { parent_args, ty: infcx.next_ty_var(span) },
)
.args;
infcx.canonicalize_user_type_annotation(ty::UserType::new(ty::UserTypeKind::TypeOf(
def_id.to_def_id(),
ty::UserArgs { args, user_self_ty: None },
)))
};
let annotation =
CanonicalUserTypeAnnotation { user_ty: Box::new(annotation), span, inferred_ty: ty };
PatKind::AscribeUserType {
subpattern: pattern,
ascription: Ascription {
annotation,
// Note that we use `Contravariant` here. See the `variance` field documentation
// for details.
variance: ty::Contravariant,
},
}
}

/// Lowers the kinds of "expression" that can appear in a HIR pattern:
Expand All @@ -635,43 +640,41 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
expr: &'tcx hir::PatExpr<'tcx>,
pat_ty: Option<Ty<'tcx>>,
) -> PatKind<'tcx> {
let (lit, neg) = match &expr.kind {
hir::PatExprKind::Path(qpath) => {
return self.lower_path(qpath, expr.hir_id, expr.span).kind;
}
match &expr.kind {
hir::PatExprKind::Path(qpath) => self.lower_path(qpath, expr.hir_id, expr.span).kind,
hir::PatExprKind::ConstBlock(anon_const) => {
return self.lower_inline_const(anon_const, expr.hir_id, expr.span);
self.lower_inline_const(anon_const, expr.hir_id, expr.span)
}
hir::PatExprKind::Lit { lit, negated } => (lit, *negated),
};

// We handle byte string literal patterns by using the pattern's type instead of the
// literal's type in `const_to_pat`: if the literal `b"..."` matches on a slice reference,
// the pattern's type will be `&[u8]` whereas the literal's type is `&[u8; 3]`; using the
// pattern's type means we'll properly translate it to a slice reference pattern. This works
// because slices and arrays have the same valtree representation.
// HACK: As an exception, use the literal's type if `pat_ty` is `String`; this can happen if
// `string_deref_patterns` is enabled. There's a special case for that when lowering to MIR.
// FIXME(deref_patterns): This hack won't be necessary once `string_deref_patterns` is
// superseded by a more general implementation of deref patterns.
let ct_ty = match pat_ty {
Some(pat_ty)
if let ty::Adt(def, _) = *pat_ty.kind()
&& self.tcx.is_lang_item(def.did(), LangItem::String) =>
{
if !self.tcx.features().string_deref_patterns() {
span_bug!(
expr.span,
"matching on `String` went through without enabling string_deref_patterns"
);
}
self.typeck_results.node_type(expr.hir_id)
hir::PatExprKind::Lit { lit, negated } => {
// We handle byte string literal patterns by using the pattern's type instead of the
// literal's type in `const_to_pat`: if the literal `b"..."` matches on a slice reference,
// the pattern's type will be `&[u8]` whereas the literal's type is `&[u8; 3]`; using the
// pattern's type means we'll properly translate it to a slice reference pattern. This works
// because slices and arrays have the same valtree representation.
// HACK: As an exception, use the literal's type if `pat_ty` is `String`; this can happen if
// `string_deref_patterns` is enabled. There's a special case for that when lowering to MIR.
// FIXME(deref_patterns): This hack won't be necessary once `string_deref_patterns` is
// superseded by a more general implementation of deref patterns.
let ct_ty = match pat_ty {
Some(pat_ty)
if let ty::Adt(def, _) = *pat_ty.kind()
&& self.tcx.is_lang_item(def.did(), LangItem::String) =>
{
if !self.tcx.features().string_deref_patterns() {
span_bug!(
expr.span,
"matching on `String` went through without enabling string_deref_patterns"
);
}
self.typeck_results.node_type(expr.hir_id)
}
Some(pat_ty) => pat_ty,
None => self.typeck_results.node_type(expr.hir_id),
};
let lit_input = LitToConstInput { lit: &lit.node, ty: ct_ty, neg: *negated };
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
self.const_to_pat(constant, ct_ty, expr.hir_id, lit.span).kind
}
Some(pat_ty) => pat_ty,
None => self.typeck_results.node_type(expr.hir_id),
};
let lit_input = LitToConstInput { lit: &lit.node, ty: ct_ty, neg };
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
self.const_to_pat(constant, ct_ty, expr.hir_id, lit.span).kind
}
}
}
3 changes: 1 addition & 2 deletions compiler/rustc_mir_build/src/thir/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,10 +740,9 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::ExpandedConstant { def_id, is_inline, subpattern } => {
PatKind::ExpandedConstant { def_id, subpattern } => {
print_indented!(self, "ExpandedConstant {", depth_lvl + 1);
print_indented!(self, format!("def_id: {def_id:?}"), depth_lvl + 2);
print_indented!(self, format!("is_inline: {is_inline:?}"), depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
Expand Down
Loading