Skip to content

fix: try to infer array type from slice pattern #19066

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ impl<'a> InferenceContext<'a> {
let ty = self.insert_type_vars(ty);
let ty = self.normalize_associated_types_in(ty);

self.infer_top_pat(*pat, &ty);
self.infer_top_pat(*pat, &ty, None);
if ty
.data(Interner)
.flags
Expand Down
23 changes: 15 additions & 8 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ use crate::{
primitive::{self, UintTy},
static_lifetime, to_chalk_trait_id,
traits::FnTrait,
Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, FnAbi, FnPointer,
FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, TraitEnvironment, TraitRef, Ty,
TyBuilder, TyExt, TyKind,
Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, DeclContext,
DeclOrigin, FnAbi, FnPointer, FnSig, FnSubst, Interner, Rawness, Scalar, Substitution,
TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
};

use super::{
Expand Down Expand Up @@ -334,7 +334,11 @@ impl InferenceContext<'_> {
ExprIsRead::No
};
let input_ty = self.infer_expr(expr, &Expectation::none(), child_is_read);
self.infer_top_pat(pat, &input_ty);
self.infer_top_pat(
pat,
&input_ty,
Some(DeclContext { origin: DeclOrigin::LetExpr }),
);
self.result.standard_types.bool_.clone()
}
Expr::Block { statements, tail, label, id } => {
Expand Down Expand Up @@ -461,7 +465,7 @@ impl InferenceContext<'_> {

// Now go through the argument patterns
for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) {
self.infer_top_pat(*arg_pat, arg_ty);
self.infer_top_pat(*arg_pat, arg_ty, None);
}

// FIXME: lift these out into a struct
Expand Down Expand Up @@ -582,7 +586,7 @@ impl InferenceContext<'_> {
let mut all_arms_diverge = Diverges::Always;
for arm in arms.iter() {
let input_ty = self.resolve_ty_shallow(&input_ty);
self.infer_top_pat(arm.pat, &input_ty);
self.infer_top_pat(arm.pat, &input_ty, None);
}

let expected = expected.adjust_for_branches(&mut self.table);
Expand Down Expand Up @@ -927,7 +931,7 @@ impl InferenceContext<'_> {
let resolver_guard =
self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr);
self.inside_assignment = true;
self.infer_top_pat(target, &rhs_ty);
self.infer_top_pat(target, &rhs_ty, None);
self.inside_assignment = false;
self.resolver.reset_to_guard(resolver_guard);
}
Expand Down Expand Up @@ -1632,8 +1636,11 @@ impl InferenceContext<'_> {
decl_ty
};

this.infer_top_pat(*pat, &ty);
let decl = DeclContext {
origin: DeclOrigin::LocalDecl { has_else: else_branch.is_some() },
};

this.infer_top_pat(*pat, &ty, Some(decl));
if let Some(expr) = else_branch {
let previous_diverges =
mem::replace(&mut this.diverges, Diverges::Maybe);
Expand Down
125 changes: 103 additions & 22 deletions crates/hir-ty/src/infer/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ use hir_def::{
expr_store::Body,
hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId},
path::Path,
HasModule,
};
use hir_expand::name::Name;
use stdx::TupleExt;

use crate::{
consteval::{try_const_usize, usize_const},
consteval::{self, try_const_usize, usize_const},
infer::{
coerce::CoerceNever, expr::ExprIsRead, BindingMode, Expectation, InferenceContext,
TypeMismatch,
},
lower::lower_to_chalk_mutability,
primitive::UintTy,
static_lifetime, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty,
TyBuilder, TyExt, TyKind,
static_lifetime, DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar,
Substitution, Ty, TyBuilder, TyExt, TyKind,
};

impl InferenceContext<'_> {
Expand All @@ -34,6 +35,7 @@ impl InferenceContext<'_> {
id: PatId,
ellipsis: Option<u32>,
subs: &[PatId],
decl: Option<DeclContext>,
) -> Ty {
let (ty, def) = self.resolve_variant(id.into(), path, true);
let var_data = def.map(|it| it.variant_data(self.db.upcast()));
Expand Down Expand Up @@ -92,13 +94,13 @@ impl InferenceContext<'_> {
}
};

self.infer_pat(subpat, &expected_ty, default_bm);
self.infer_pat(subpat, &expected_ty, default_bm, decl);
}
}
None => {
let err_ty = self.err_ty();
for &inner in subs {
self.infer_pat(inner, &err_ty, default_bm);
self.infer_pat(inner, &err_ty, default_bm, decl);
}
}
}
Expand All @@ -114,6 +116,7 @@ impl InferenceContext<'_> {
default_bm: BindingMode,
id: PatId,
subs: impl ExactSizeIterator<Item = (Name, PatId)>,
decl: Option<DeclContext>,
) -> Ty {
let (ty, def) = self.resolve_variant(id.into(), path, false);
if let Some(variant) = def {
Expand Down Expand Up @@ -162,13 +165,13 @@ impl InferenceContext<'_> {
}
};

self.infer_pat(inner, &expected_ty, default_bm);
self.infer_pat(inner, &expected_ty, default_bm, decl);
}
}
None => {
let err_ty = self.err_ty();
for (_, inner) in subs {
self.infer_pat(inner, &err_ty, default_bm);
self.infer_pat(inner, &err_ty, default_bm, decl);
}
}
}
Expand All @@ -185,6 +188,7 @@ impl InferenceContext<'_> {
default_bm: BindingMode,
ellipsis: Option<u32>,
subs: &[PatId],
decl: Option<DeclContext>,
) -> Ty {
let expected = self.resolve_ty_shallow(expected);
let expectations = match expected.as_tuple() {
Expand All @@ -209,12 +213,12 @@ impl InferenceContext<'_> {

// Process pre
for (ty, pat) in inner_tys.iter_mut().zip(pre) {
*ty = self.infer_pat(*pat, ty, default_bm);
*ty = self.infer_pat(*pat, ty, default_bm, decl);
}

// Process post
for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) {
*ty = self.infer_pat(*pat, ty, default_bm);
*ty = self.infer_pat(*pat, ty, default_bm, decl);
}

TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
Expand All @@ -223,11 +227,17 @@ impl InferenceContext<'_> {

/// The resolver needs to be updated to the surrounding expression when inside assignment
/// (because there, `Pat::Path` can refer to a variable).
pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) {
self.infer_pat(pat, expected, BindingMode::default());
pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty, decl: Option<DeclContext>) {
self.infer_pat(pat, expected, BindingMode::default(), decl);
}

fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty {
fn infer_pat(
&mut self,
pat: PatId,
expected: &Ty,
mut default_bm: BindingMode,
decl: Option<DeclContext>,
) -> Ty {
let mut expected = self.resolve_ty_shallow(expected);

if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment {
Expand Down Expand Up @@ -261,11 +271,11 @@ impl InferenceContext<'_> {

let ty = match &self.body[pat] {
Pat::Tuple { args, ellipsis } => {
self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args)
self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args, decl)
}
Pat::Or(pats) => {
for pat in pats.iter() {
self.infer_pat(*pat, &expected, default_bm);
self.infer_pat(*pat, &expected, default_bm, decl);
}
expected.clone()
}
Expand All @@ -274,6 +284,7 @@ impl InferenceContext<'_> {
lower_to_chalk_mutability(mutability),
&expected,
default_bm,
decl,
),
Pat::TupleStruct { path: p, args: subpats, ellipsis } => self
.infer_tuple_struct_pat_like(
Expand All @@ -283,10 +294,11 @@ impl InferenceContext<'_> {
pat,
*ellipsis,
subpats,
decl,
),
Pat::Record { path: p, args: fields, ellipsis: _ } => {
let subs = fields.iter().map(|f| (f.name.clone(), f.pat));
self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs)
self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs, decl)
}
Pat::Path(path) => {
let ty = self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty());
Expand Down Expand Up @@ -319,10 +331,10 @@ impl InferenceContext<'_> {
}
}
Pat::Bind { id, subpat } => {
return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected);
return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected, decl);
}
Pat::Slice { prefix, slice, suffix } => {
self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm)
self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm, decl)
}
Pat::Wild => expected.clone(),
Pat::Range { .. } => {
Expand All @@ -345,7 +357,7 @@ impl InferenceContext<'_> {
_ => (self.result.standard_types.unknown.clone(), None),
};

let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm);
let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm, decl);
let mut b = TyBuilder::adt(self.db, box_adt).push(inner_ty);

if let Some(alloc_ty) = alloc_ty {
Expand Down Expand Up @@ -420,6 +432,7 @@ impl InferenceContext<'_> {
mutability: Mutability,
expected: &Ty,
default_bm: BindingMode,
decl: Option<DeclContext>,
) -> Ty {
let (expectation_type, expectation_lt) = match expected.as_reference() {
Some((inner_ty, lifetime, _exp_mut)) => (inner_ty.clone(), lifetime.clone()),
Expand All @@ -433,7 +446,7 @@ impl InferenceContext<'_> {
(inner_ty, inner_lt)
}
};
let subty = self.infer_pat(inner_pat, &expectation_type, default_bm);
let subty = self.infer_pat(inner_pat, &expectation_type, default_bm, decl);
TyKind::Ref(mutability, expectation_lt, subty).intern(Interner)
}

Expand All @@ -444,6 +457,7 @@ impl InferenceContext<'_> {
default_bm: BindingMode,
subpat: Option<PatId>,
expected: &Ty,
decl: Option<DeclContext>,
) -> Ty {
let Binding { mode, .. } = self.body.bindings[binding];
let mode = if mode == BindingAnnotation::Unannotated {
Expand All @@ -454,7 +468,7 @@ impl InferenceContext<'_> {
self.result.binding_modes.insert(pat, mode);

let inner_ty = match subpat {
Some(subpat) => self.infer_pat(subpat, expected, default_bm),
Some(subpat) => self.infer_pat(subpat, expected, default_bm, decl),
None => expected.clone(),
};
let inner_ty = self.insert_type_vars_shallow(inner_ty);
Expand All @@ -478,14 +492,28 @@ impl InferenceContext<'_> {
slice: &Option<PatId>,
suffix: &[PatId],
default_bm: BindingMode,
decl: Option<DeclContext>,
) -> Ty {
let expected = self.resolve_ty_shallow(expected);

// If `expected` is an infer ty, we try to equate it to an array if the given pattern
// allows it. See issue #16609
if self.pat_is_irrefutable(decl) && expected.is_ty_var() {
if let Some(resolved_array_ty) =
self.try_resolve_slice_ty_to_array_ty(prefix, suffix, slice)
{
self.unify(&expected, &resolved_array_ty);
}
}

let expected = self.resolve_ty_shallow(&expected);
let elem_ty = match expected.kind(Interner) {
TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(),
_ => self.err_ty(),
};

for &pat_id in prefix.iter().chain(suffix.iter()) {
self.infer_pat(pat_id, &elem_ty, default_bm);
self.infer_pat(pat_id, &elem_ty, default_bm, decl);
}

if let &Some(slice_pat_id) = slice {
Expand All @@ -499,7 +527,7 @@ impl InferenceContext<'_> {
_ => TyKind::Slice(elem_ty.clone()),
}
.intern(Interner);
self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm);
self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm, decl);
}

match expected.kind(Interner) {
Expand Down Expand Up @@ -553,6 +581,59 @@ impl InferenceContext<'_> {
| Pat::Expr(_) => false,
}
}

fn try_resolve_slice_ty_to_array_ty(
&mut self,
before: &[PatId],
suffix: &[PatId],
slice: &Option<PatId>,
) -> Option<Ty> {
if !slice.is_none() {
return None;
}

let len = before.len() + suffix.len();
let size =
consteval::usize_const(self.db, Some(len as u128), self.owner.krate(self.db.upcast()));

let elem_ty = self.table.new_type_var();
let array_ty = TyKind::Array(elem_ty.clone(), size).intern(Interner);
Some(array_ty)
}

/// Used to determine whether we can infer the expected type in the slice pattern to be of type array.
/// This is only possible if we're in an irrefutable pattern. If we were to allow this in refutable
/// patterns we wouldn't e.g. report ambiguity in the following situation:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler doesn't report ambiguity (https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=574bb654d0bd9905773ec8a82246c446), are you sure this logic is correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's copied from the corresponding rustc PR, though not sure why the text wasnt just copied verbatim https://github.com/rust-lang/rust/pull/113199/files#diff-9f26e4ec8d6ac64edbb3532a590592556b268b0e33f9fd4264d5d449aebbecf7R2061-R2090

///
/// ```ignore(rust)
/// struct Zeroes;
/// const ARR: [usize; 2] = [0; 2];
/// const ARR2: [usize; 2] = [2; 2];
///
/// impl Into<&'static [usize; 2]> for Zeroes {
/// fn into(self) -> &'static [usize; 2] {
/// &ARR
/// }
/// }
///
/// impl Into<&'static [usize]> for Zeroes {
/// fn into(self) -> &'static [usize] {
/// &ARR2
/// }
/// }
///
/// fn main() {
/// let &[a, b]: &[usize] = Zeroes.into() else {
/// ..
/// };
/// }
/// ```
///
/// If we're in an irrefutable pattern we prefer the array impl candidate given that
/// the slice impl candidate would be rejected anyway (if no ambiguity existed).
fn pat_is_irrefutable(&self, decl_ctxt: Option<DeclContext>) -> bool {
matches!(decl_ctxt, Some(DeclContext { origin: DeclOrigin::LocalDecl { has_else: false } }))
}
}

pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool {
Expand Down
Loading