diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 3c258e3c4cf4..59c6ba18eee1 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -943,7 +943,7 @@ impl<'a> InferenceContext<'a> { let ty = self.insert_type_vars(ty); let ty = self.normalize_associated_types_in(ty); - self.infer_top_pat(*pat, &ty); + self.infer_top_pat(*pat, &ty, None); if ty .data(Interner) .flags diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index b951443897cb..86e5afdb5092 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -43,9 +43,9 @@ use crate::{ primitive::{self, UintTy}, static_lifetime, to_chalk_trait_id, traits::FnTrait, - Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, FnAbi, FnPointer, - FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, TraitEnvironment, TraitRef, Ty, - TyBuilder, TyExt, TyKind, + Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, DeclContext, + DeclOrigin, FnAbi, FnPointer, FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, + TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind, }; use super::{ @@ -334,7 +334,11 @@ impl InferenceContext<'_> { ExprIsRead::No }; let input_ty = self.infer_expr(expr, &Expectation::none(), child_is_read); - self.infer_top_pat(pat, &input_ty); + self.infer_top_pat( + pat, + &input_ty, + Some(DeclContext { origin: DeclOrigin::LetExpr }), + ); self.result.standard_types.bool_.clone() } Expr::Block { statements, tail, label, id } => { @@ -461,7 +465,7 @@ impl InferenceContext<'_> { // Now go through the argument patterns for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) { - self.infer_top_pat(*arg_pat, arg_ty); + self.infer_top_pat(*arg_pat, arg_ty, None); } // FIXME: lift these out into a struct @@ -582,7 +586,7 @@ impl InferenceContext<'_> { let mut all_arms_diverge = Diverges::Always; for arm in arms.iter() { let input_ty = self.resolve_ty_shallow(&input_ty); - self.infer_top_pat(arm.pat, &input_ty); + self.infer_top_pat(arm.pat, &input_ty, None); } let expected = expected.adjust_for_branches(&mut self.table); @@ -927,7 +931,7 @@ impl InferenceContext<'_> { let resolver_guard = self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr); self.inside_assignment = true; - self.infer_top_pat(target, &rhs_ty); + self.infer_top_pat(target, &rhs_ty, None); self.inside_assignment = false; self.resolver.reset_to_guard(resolver_guard); } @@ -1632,8 +1636,11 @@ impl InferenceContext<'_> { decl_ty }; - this.infer_top_pat(*pat, &ty); + let decl = DeclContext { + origin: DeclOrigin::LocalDecl { has_else: else_branch.is_some() }, + }; + this.infer_top_pat(*pat, &ty, Some(decl)); if let Some(expr) = else_branch { let previous_diverges = mem::replace(&mut this.diverges, Diverges::Maybe); diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs index ca8d5bae5e50..5ff22bea34de 100644 --- a/crates/hir-ty/src/infer/pat.rs +++ b/crates/hir-ty/src/infer/pat.rs @@ -6,20 +6,21 @@ use hir_def::{ expr_store::Body, hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId}, path::Path, + HasModule, }; use hir_expand::name::Name; use stdx::TupleExt; use crate::{ - consteval::{try_const_usize, usize_const}, + consteval::{self, try_const_usize, usize_const}, infer::{ coerce::CoerceNever, expr::ExprIsRead, BindingMode, Expectation, InferenceContext, TypeMismatch, }, lower::lower_to_chalk_mutability, primitive::UintTy, - static_lifetime, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty, - TyBuilder, TyExt, TyKind, + static_lifetime, DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar, + Substitution, Ty, TyBuilder, TyExt, TyKind, }; impl InferenceContext<'_> { @@ -34,6 +35,7 @@ impl InferenceContext<'_> { id: PatId, ellipsis: Option, subs: &[PatId], + decl: Option, ) -> Ty { let (ty, def) = self.resolve_variant(id.into(), path, true); let var_data = def.map(|it| it.variant_data(self.db.upcast())); @@ -92,13 +94,13 @@ impl InferenceContext<'_> { } }; - self.infer_pat(subpat, &expected_ty, default_bm); + self.infer_pat(subpat, &expected_ty, default_bm, decl); } } None => { let err_ty = self.err_ty(); for &inner in subs { - self.infer_pat(inner, &err_ty, default_bm); + self.infer_pat(inner, &err_ty, default_bm, decl); } } } @@ -114,6 +116,7 @@ impl InferenceContext<'_> { default_bm: BindingMode, id: PatId, subs: impl ExactSizeIterator, + decl: Option, ) -> Ty { let (ty, def) = self.resolve_variant(id.into(), path, false); if let Some(variant) = def { @@ -162,13 +165,13 @@ impl InferenceContext<'_> { } }; - self.infer_pat(inner, &expected_ty, default_bm); + self.infer_pat(inner, &expected_ty, default_bm, decl); } } None => { let err_ty = self.err_ty(); for (_, inner) in subs { - self.infer_pat(inner, &err_ty, default_bm); + self.infer_pat(inner, &err_ty, default_bm, decl); } } } @@ -185,6 +188,7 @@ impl InferenceContext<'_> { default_bm: BindingMode, ellipsis: Option, subs: &[PatId], + decl: Option, ) -> Ty { let expected = self.resolve_ty_shallow(expected); let expectations = match expected.as_tuple() { @@ -209,12 +213,12 @@ impl InferenceContext<'_> { // Process pre for (ty, pat) in inner_tys.iter_mut().zip(pre) { - *ty = self.infer_pat(*pat, ty, default_bm); + *ty = self.infer_pat(*pat, ty, default_bm, decl); } // Process post for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) { - *ty = self.infer_pat(*pat, ty, default_bm); + *ty = self.infer_pat(*pat, ty, default_bm, decl); } TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys)) @@ -223,11 +227,17 @@ impl InferenceContext<'_> { /// The resolver needs to be updated to the surrounding expression when inside assignment /// (because there, `Pat::Path` can refer to a variable). - pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) { - self.infer_pat(pat, expected, BindingMode::default()); + pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty, decl: Option) { + self.infer_pat(pat, expected, BindingMode::default(), decl); } - fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty { + fn infer_pat( + &mut self, + pat: PatId, + expected: &Ty, + mut default_bm: BindingMode, + decl: Option, + ) -> Ty { let mut expected = self.resolve_ty_shallow(expected); if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment { @@ -261,11 +271,11 @@ impl InferenceContext<'_> { let ty = match &self.body[pat] { Pat::Tuple { args, ellipsis } => { - self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args) + self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args, decl) } Pat::Or(pats) => { for pat in pats.iter() { - self.infer_pat(*pat, &expected, default_bm); + self.infer_pat(*pat, &expected, default_bm, decl); } expected.clone() } @@ -274,6 +284,7 @@ impl InferenceContext<'_> { lower_to_chalk_mutability(mutability), &expected, default_bm, + decl, ), Pat::TupleStruct { path: p, args: subpats, ellipsis } => self .infer_tuple_struct_pat_like( @@ -283,10 +294,11 @@ impl InferenceContext<'_> { pat, *ellipsis, subpats, + decl, ), Pat::Record { path: p, args: fields, ellipsis: _ } => { let subs = fields.iter().map(|f| (f.name.clone(), f.pat)); - self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs) + self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs, decl) } Pat::Path(path) => { let ty = self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty()); @@ -319,10 +331,10 @@ impl InferenceContext<'_> { } } Pat::Bind { id, subpat } => { - return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected); + return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected, decl); } Pat::Slice { prefix, slice, suffix } => { - self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm) + self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm, decl) } Pat::Wild => expected.clone(), Pat::Range { .. } => { @@ -345,7 +357,7 @@ impl InferenceContext<'_> { _ => (self.result.standard_types.unknown.clone(), None), }; - let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm); + let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm, decl); let mut b = TyBuilder::adt(self.db, box_adt).push(inner_ty); if let Some(alloc_ty) = alloc_ty { @@ -420,6 +432,7 @@ impl InferenceContext<'_> { mutability: Mutability, expected: &Ty, default_bm: BindingMode, + decl: Option, ) -> Ty { let (expectation_type, expectation_lt) = match expected.as_reference() { Some((inner_ty, lifetime, _exp_mut)) => (inner_ty.clone(), lifetime.clone()), @@ -433,7 +446,7 @@ impl InferenceContext<'_> { (inner_ty, inner_lt) } }; - let subty = self.infer_pat(inner_pat, &expectation_type, default_bm); + let subty = self.infer_pat(inner_pat, &expectation_type, default_bm, decl); TyKind::Ref(mutability, expectation_lt, subty).intern(Interner) } @@ -444,6 +457,7 @@ impl InferenceContext<'_> { default_bm: BindingMode, subpat: Option, expected: &Ty, + decl: Option, ) -> Ty { let Binding { mode, .. } = self.body.bindings[binding]; let mode = if mode == BindingAnnotation::Unannotated { @@ -454,7 +468,7 @@ impl InferenceContext<'_> { self.result.binding_modes.insert(pat, mode); let inner_ty = match subpat { - Some(subpat) => self.infer_pat(subpat, expected, default_bm), + Some(subpat) => self.infer_pat(subpat, expected, default_bm, decl), None => expected.clone(), }; let inner_ty = self.insert_type_vars_shallow(inner_ty); @@ -478,14 +492,28 @@ impl InferenceContext<'_> { slice: &Option, suffix: &[PatId], default_bm: BindingMode, + decl: Option, ) -> Ty { + let expected = self.resolve_ty_shallow(expected); + + // If `expected` is an infer ty, we try to equate it to an array if the given pattern + // allows it. See issue #16609 + if self.pat_is_irrefutable(decl) && expected.is_ty_var() { + if let Some(resolved_array_ty) = + self.try_resolve_slice_ty_to_array_ty(prefix, suffix, slice) + { + self.unify(&expected, &resolved_array_ty); + } + } + + let expected = self.resolve_ty_shallow(&expected); let elem_ty = match expected.kind(Interner) { TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(), _ => self.err_ty(), }; for &pat_id in prefix.iter().chain(suffix.iter()) { - self.infer_pat(pat_id, &elem_ty, default_bm); + self.infer_pat(pat_id, &elem_ty, default_bm, decl); } if let &Some(slice_pat_id) = slice { @@ -499,7 +527,7 @@ impl InferenceContext<'_> { _ => TyKind::Slice(elem_ty.clone()), } .intern(Interner); - self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm); + self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm, decl); } match expected.kind(Interner) { @@ -553,6 +581,59 @@ impl InferenceContext<'_> { | Pat::Expr(_) => false, } } + + fn try_resolve_slice_ty_to_array_ty( + &mut self, + before: &[PatId], + suffix: &[PatId], + slice: &Option, + ) -> Option { + if !slice.is_none() { + return None; + } + + let len = before.len() + suffix.len(); + let size = + consteval::usize_const(self.db, Some(len as u128), self.owner.krate(self.db.upcast())); + + let elem_ty = self.table.new_type_var(); + let array_ty = TyKind::Array(elem_ty.clone(), size).intern(Interner); + Some(array_ty) + } + + /// Used to determine whether we can infer the expected type in the slice pattern to be of type array. + /// This is only possible if we're in an irrefutable pattern. If we were to allow this in refutable + /// patterns we wouldn't e.g. report ambiguity in the following situation: + /// + /// ```ignore(rust) + /// struct Zeroes; + /// const ARR: [usize; 2] = [0; 2]; + /// const ARR2: [usize; 2] = [2; 2]; + /// + /// impl Into<&'static [usize; 2]> for Zeroes { + /// fn into(self) -> &'static [usize; 2] { + /// &ARR + /// } + /// } + /// + /// impl Into<&'static [usize]> for Zeroes { + /// fn into(self) -> &'static [usize] { + /// &ARR2 + /// } + /// } + /// + /// fn main() { + /// let &[a, b]: &[usize] = Zeroes.into() else { + /// .. + /// }; + /// } + /// ``` + /// + /// If we're in an irrefutable pattern we prefer the array impl candidate given that + /// the slice impl candidate would be rejected anyway (if no ambiguity existed). + fn pat_is_irrefutable(&self, decl_ctxt: Option) -> bool { + matches!(decl_ctxt, Some(DeclContext { origin: DeclOrigin::LocalDecl { has_else: false } })) + } } pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool { diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index 4b159b7541e6..55d81875a2be 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -1049,3 +1049,20 @@ pub fn known_const_to_ast( } Some(make::expr_const_value(konst.display(db, edition).to_string().as_str())) } + +#[derive(Debug, Copy, Clone)] +pub(crate) enum DeclOrigin { + LetExpr, + /// from `let x = ..` + LocalDecl { + has_else: bool, + }, +} + +/// Provides context for checking patterns in declarations. More specifically this +/// allows us to infer array types if the pattern is irrefutable and allows us to infer +/// the size of the array. See issue rust-lang/rust#76342. +#[derive(Debug, Copy, Clone)] +pub(crate) struct DeclContext { + pub(crate) origin: DeclOrigin, +} diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 156366045705..50a1ecd006d8 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -3814,3 +3814,50 @@ async fn foo(a: (), b: i32) -> u32 { "#, ); } + +#[test] +fn irrefutable_slices() { + check_infer( + r#" +//- minicore: from +struct A; + +impl From for [u8; 2] { + fn from(a: A) -> Self { + [0; 2] + } +} +impl From for [u8; 3] { + fn from(a: A) -> Self { + [0; 3] + } +} + + +fn main() { + let a = A; + let [b, c] = a.into(); +} +"#, + expect![[r#" + 50..51 'a': A + 64..86 '{ ... }': [u8; 2] + 74..80 '[0; 2]': [u8; 2] + 75..76 '0': u8 + 78..79 '2': usize + 128..129 'a': A + 142..164 '{ ... }': [u8; 3] + 152..158 '[0; 3]': [u8; 3] + 153..154 '0': u8 + 156..157 '3': usize + 179..224 '{ ...o(); }': () + 189..190 'a': A + 193..194 'A': A + 204..210 '[b, c]': [u8; 2] + 205..206 'b': u8 + 208..209 'c': u8 + 213..214 'a': A + 213..221 'a.into()': [u8; 2] + "#]], + ); +}