Skip to content

Commit 13c17db

Browse files
authored
Merge pull request #19066 from alibektas/slice_pattern_type_inference
fix: try to infer array type from slice pattern
2 parents cd0753a + 135fca9 commit 13c17db

File tree

5 files changed

+183
-31
lines changed

5 files changed

+183
-31
lines changed

crates/hir-ty/src/infer.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ impl<'a> InferenceContext<'a> {
946946
let ty = self.insert_type_vars(ty);
947947
let ty = self.normalize_associated_types_in(ty);
948948

949-
self.infer_top_pat(*pat, &ty);
949+
self.infer_top_pat(*pat, &ty, None);
950950
if ty
951951
.data(Interner)
952952
.flags

crates/hir-ty/src/infer/expr.rs

+15-8
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ use crate::{
4343
primitive::{self, UintTy},
4444
static_lifetime, to_chalk_trait_id,
4545
traits::FnTrait,
46-
Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, FnAbi, FnPointer,
47-
FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, TraitEnvironment, TraitRef, Ty,
48-
TyBuilder, TyExt, TyKind,
46+
Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, DeclContext,
47+
DeclOrigin, FnAbi, FnPointer, FnSig, FnSubst, Interner, Rawness, Scalar, Substitution,
48+
TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
4949
};
5050

5151
use super::{
@@ -334,7 +334,11 @@ impl InferenceContext<'_> {
334334
ExprIsRead::No
335335
};
336336
let input_ty = self.infer_expr(expr, &Expectation::none(), child_is_read);
337-
self.infer_top_pat(pat, &input_ty);
337+
self.infer_top_pat(
338+
pat,
339+
&input_ty,
340+
Some(DeclContext { origin: DeclOrigin::LetExpr }),
341+
);
338342
self.result.standard_types.bool_.clone()
339343
}
340344
Expr::Block { statements, tail, label, id } => {
@@ -461,7 +465,7 @@ impl InferenceContext<'_> {
461465

462466
// Now go through the argument patterns
463467
for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) {
464-
self.infer_top_pat(*arg_pat, arg_ty);
468+
self.infer_top_pat(*arg_pat, arg_ty, None);
465469
}
466470

467471
// FIXME: lift these out into a struct
@@ -582,7 +586,7 @@ impl InferenceContext<'_> {
582586
let mut all_arms_diverge = Diverges::Always;
583587
for arm in arms.iter() {
584588
let input_ty = self.resolve_ty_shallow(&input_ty);
585-
self.infer_top_pat(arm.pat, &input_ty);
589+
self.infer_top_pat(arm.pat, &input_ty, None);
586590
}
587591

588592
let expected = expected.adjust_for_branches(&mut self.table);
@@ -927,7 +931,7 @@ impl InferenceContext<'_> {
927931
let resolver_guard =
928932
self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr);
929933
self.inside_assignment = true;
930-
self.infer_top_pat(target, &rhs_ty);
934+
self.infer_top_pat(target, &rhs_ty, None);
931935
self.inside_assignment = false;
932936
self.resolver.reset_to_guard(resolver_guard);
933937
}
@@ -1632,8 +1636,11 @@ impl InferenceContext<'_> {
16321636
decl_ty
16331637
};
16341638

1635-
this.infer_top_pat(*pat, &ty);
1639+
let decl = DeclContext {
1640+
origin: DeclOrigin::LocalDecl { has_else: else_branch.is_some() },
1641+
};
16361642

1643+
this.infer_top_pat(*pat, &ty, Some(decl));
16371644
if let Some(expr) = else_branch {
16381645
let previous_diverges =
16391646
mem::replace(&mut this.diverges, Diverges::Maybe);

crates/hir-ty/src/infer/pat.rs

+103-22
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@ use hir_def::{
66
expr_store::Body,
77
hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId},
88
path::Path,
9+
HasModule,
910
};
1011
use hir_expand::name::Name;
1112
use stdx::TupleExt;
1213

1314
use crate::{
14-
consteval::{try_const_usize, usize_const},
15+
consteval::{self, try_const_usize, usize_const},
1516
infer::{
1617
coerce::CoerceNever, expr::ExprIsRead, BindingMode, Expectation, InferenceContext,
1718
TypeMismatch,
1819
},
1920
lower::lower_to_chalk_mutability,
2021
primitive::UintTy,
21-
static_lifetime, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty,
22-
TyBuilder, TyExt, TyKind,
22+
static_lifetime, DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar,
23+
Substitution, Ty, TyBuilder, TyExt, TyKind,
2324
};
2425

2526
impl InferenceContext<'_> {
@@ -34,6 +35,7 @@ impl InferenceContext<'_> {
3435
id: PatId,
3536
ellipsis: Option<u32>,
3637
subs: &[PatId],
38+
decl: Option<DeclContext>,
3739
) -> Ty {
3840
let (ty, def) = self.resolve_variant(id.into(), path, true);
3941
let var_data = def.map(|it| it.variant_data(self.db.upcast()));
@@ -92,13 +94,13 @@ impl InferenceContext<'_> {
9294
}
9395
};
9496

95-
self.infer_pat(subpat, &expected_ty, default_bm);
97+
self.infer_pat(subpat, &expected_ty, default_bm, decl);
9698
}
9799
}
98100
None => {
99101
let err_ty = self.err_ty();
100102
for &inner in subs {
101-
self.infer_pat(inner, &err_ty, default_bm);
103+
self.infer_pat(inner, &err_ty, default_bm, decl);
102104
}
103105
}
104106
}
@@ -114,6 +116,7 @@ impl InferenceContext<'_> {
114116
default_bm: BindingMode,
115117
id: PatId,
116118
subs: impl ExactSizeIterator<Item = (Name, PatId)>,
119+
decl: Option<DeclContext>,
117120
) -> Ty {
118121
let (ty, def) = self.resolve_variant(id.into(), path, false);
119122
if let Some(variant) = def {
@@ -162,13 +165,13 @@ impl InferenceContext<'_> {
162165
}
163166
};
164167

165-
self.infer_pat(inner, &expected_ty, default_bm);
168+
self.infer_pat(inner, &expected_ty, default_bm, decl);
166169
}
167170
}
168171
None => {
169172
let err_ty = self.err_ty();
170173
for (_, inner) in subs {
171-
self.infer_pat(inner, &err_ty, default_bm);
174+
self.infer_pat(inner, &err_ty, default_bm, decl);
172175
}
173176
}
174177
}
@@ -185,6 +188,7 @@ impl InferenceContext<'_> {
185188
default_bm: BindingMode,
186189
ellipsis: Option<u32>,
187190
subs: &[PatId],
191+
decl: Option<DeclContext>,
188192
) -> Ty {
189193
let expected = self.resolve_ty_shallow(expected);
190194
let expectations = match expected.as_tuple() {
@@ -209,12 +213,12 @@ impl InferenceContext<'_> {
209213

210214
// Process pre
211215
for (ty, pat) in inner_tys.iter_mut().zip(pre) {
212-
*ty = self.infer_pat(*pat, ty, default_bm);
216+
*ty = self.infer_pat(*pat, ty, default_bm, decl);
213217
}
214218

215219
// Process post
216220
for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) {
217-
*ty = self.infer_pat(*pat, ty, default_bm);
221+
*ty = self.infer_pat(*pat, ty, default_bm, decl);
218222
}
219223

220224
TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
@@ -223,11 +227,17 @@ impl InferenceContext<'_> {
223227

224228
/// The resolver needs to be updated to the surrounding expression when inside assignment
225229
/// (because there, `Pat::Path` can refer to a variable).
226-
pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) {
227-
self.infer_pat(pat, expected, BindingMode::default());
230+
pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty, decl: Option<DeclContext>) {
231+
self.infer_pat(pat, expected, BindingMode::default(), decl);
228232
}
229233

230-
fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty {
234+
fn infer_pat(
235+
&mut self,
236+
pat: PatId,
237+
expected: &Ty,
238+
mut default_bm: BindingMode,
239+
decl: Option<DeclContext>,
240+
) -> Ty {
231241
let mut expected = self.resolve_ty_shallow(expected);
232242

233243
if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment {
@@ -261,11 +271,11 @@ impl InferenceContext<'_> {
261271

262272
let ty = match &self.body[pat] {
263273
Pat::Tuple { args, ellipsis } => {
264-
self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args)
274+
self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args, decl)
265275
}
266276
Pat::Or(pats) => {
267277
for pat in pats.iter() {
268-
self.infer_pat(*pat, &expected, default_bm);
278+
self.infer_pat(*pat, &expected, default_bm, decl);
269279
}
270280
expected.clone()
271281
}
@@ -274,6 +284,7 @@ impl InferenceContext<'_> {
274284
lower_to_chalk_mutability(mutability),
275285
&expected,
276286
default_bm,
287+
decl,
277288
),
278289
Pat::TupleStruct { path: p, args: subpats, ellipsis } => self
279290
.infer_tuple_struct_pat_like(
@@ -283,10 +294,11 @@ impl InferenceContext<'_> {
283294
pat,
284295
*ellipsis,
285296
subpats,
297+
decl,
286298
),
287299
Pat::Record { path: p, args: fields, ellipsis: _ } => {
288300
let subs = fields.iter().map(|f| (f.name.clone(), f.pat));
289-
self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs)
301+
self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs, decl)
290302
}
291303
Pat::Path(path) => {
292304
let ty = self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty());
@@ -319,10 +331,10 @@ impl InferenceContext<'_> {
319331
}
320332
}
321333
Pat::Bind { id, subpat } => {
322-
return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected);
334+
return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected, decl);
323335
}
324336
Pat::Slice { prefix, slice, suffix } => {
325-
self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm)
337+
self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm, decl)
326338
}
327339
Pat::Wild => expected.clone(),
328340
Pat::Range { .. } => {
@@ -345,7 +357,7 @@ impl InferenceContext<'_> {
345357
_ => (self.result.standard_types.unknown.clone(), None),
346358
};
347359

348-
let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm);
360+
let inner_ty = self.infer_pat(*inner, &inner_ty, default_bm, decl);
349361
let mut b = TyBuilder::adt(self.db, box_adt).push(inner_ty);
350362

351363
if let Some(alloc_ty) = alloc_ty {
@@ -420,6 +432,7 @@ impl InferenceContext<'_> {
420432
mutability: Mutability,
421433
expected: &Ty,
422434
default_bm: BindingMode,
435+
decl: Option<DeclContext>,
423436
) -> Ty {
424437
let (expectation_type, expectation_lt) = match expected.as_reference() {
425438
Some((inner_ty, lifetime, _exp_mut)) => (inner_ty.clone(), lifetime.clone()),
@@ -433,7 +446,7 @@ impl InferenceContext<'_> {
433446
(inner_ty, inner_lt)
434447
}
435448
};
436-
let subty = self.infer_pat(inner_pat, &expectation_type, default_bm);
449+
let subty = self.infer_pat(inner_pat, &expectation_type, default_bm, decl);
437450
TyKind::Ref(mutability, expectation_lt, subty).intern(Interner)
438451
}
439452

@@ -444,6 +457,7 @@ impl InferenceContext<'_> {
444457
default_bm: BindingMode,
445458
subpat: Option<PatId>,
446459
expected: &Ty,
460+
decl: Option<DeclContext>,
447461
) -> Ty {
448462
let Binding { mode, .. } = self.body.bindings[binding];
449463
let mode = if mode == BindingAnnotation::Unannotated {
@@ -454,7 +468,7 @@ impl InferenceContext<'_> {
454468
self.result.binding_modes.insert(pat, mode);
455469

456470
let inner_ty = match subpat {
457-
Some(subpat) => self.infer_pat(subpat, expected, default_bm),
471+
Some(subpat) => self.infer_pat(subpat, expected, default_bm, decl),
458472
None => expected.clone(),
459473
};
460474
let inner_ty = self.insert_type_vars_shallow(inner_ty);
@@ -478,14 +492,28 @@ impl InferenceContext<'_> {
478492
slice: &Option<PatId>,
479493
suffix: &[PatId],
480494
default_bm: BindingMode,
495+
decl: Option<DeclContext>,
481496
) -> Ty {
497+
let expected = self.resolve_ty_shallow(expected);
498+
499+
// If `expected` is an infer ty, we try to equate it to an array if the given pattern
500+
// allows it. See issue #16609
501+
if self.pat_is_irrefutable(decl) && expected.is_ty_var() {
502+
if let Some(resolved_array_ty) =
503+
self.try_resolve_slice_ty_to_array_ty(prefix, suffix, slice)
504+
{
505+
self.unify(&expected, &resolved_array_ty);
506+
}
507+
}
508+
509+
let expected = self.resolve_ty_shallow(&expected);
482510
let elem_ty = match expected.kind(Interner) {
483511
TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(),
484512
_ => self.err_ty(),
485513
};
486514

487515
for &pat_id in prefix.iter().chain(suffix.iter()) {
488-
self.infer_pat(pat_id, &elem_ty, default_bm);
516+
self.infer_pat(pat_id, &elem_ty, default_bm, decl);
489517
}
490518

491519
if let &Some(slice_pat_id) = slice {
@@ -499,7 +527,7 @@ impl InferenceContext<'_> {
499527
_ => TyKind::Slice(elem_ty.clone()),
500528
}
501529
.intern(Interner);
502-
self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm);
530+
self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm, decl);
503531
}
504532

505533
match expected.kind(Interner) {
@@ -553,6 +581,59 @@ impl InferenceContext<'_> {
553581
| Pat::Expr(_) => false,
554582
}
555583
}
584+
585+
fn try_resolve_slice_ty_to_array_ty(
586+
&mut self,
587+
before: &[PatId],
588+
suffix: &[PatId],
589+
slice: &Option<PatId>,
590+
) -> Option<Ty> {
591+
if !slice.is_none() {
592+
return None;
593+
}
594+
595+
let len = before.len() + suffix.len();
596+
let size =
597+
consteval::usize_const(self.db, Some(len as u128), self.owner.krate(self.db.upcast()));
598+
599+
let elem_ty = self.table.new_type_var();
600+
let array_ty = TyKind::Array(elem_ty.clone(), size).intern(Interner);
601+
Some(array_ty)
602+
}
603+
604+
/// Used to determine whether we can infer the expected type in the slice pattern to be of type array.
605+
/// This is only possible if we're in an irrefutable pattern. If we were to allow this in refutable
606+
/// patterns we wouldn't e.g. report ambiguity in the following situation:
607+
///
608+
/// ```ignore(rust)
609+
/// struct Zeroes;
610+
/// const ARR: [usize; 2] = [0; 2];
611+
/// const ARR2: [usize; 2] = [2; 2];
612+
///
613+
/// impl Into<&'static [usize; 2]> for Zeroes {
614+
/// fn into(self) -> &'static [usize; 2] {
615+
/// &ARR
616+
/// }
617+
/// }
618+
///
619+
/// impl Into<&'static [usize]> for Zeroes {
620+
/// fn into(self) -> &'static [usize] {
621+
/// &ARR2
622+
/// }
623+
/// }
624+
///
625+
/// fn main() {
626+
/// let &[a, b]: &[usize] = Zeroes.into() else {
627+
/// ..
628+
/// };
629+
/// }
630+
/// ```
631+
///
632+
/// If we're in an irrefutable pattern we prefer the array impl candidate given that
633+
/// the slice impl candidate would be rejected anyway (if no ambiguity existed).
634+
fn pat_is_irrefutable(&self, decl_ctxt: Option<DeclContext>) -> bool {
635+
matches!(decl_ctxt, Some(DeclContext { origin: DeclOrigin::LocalDecl { has_else: false } }))
636+
}
556637
}
557638

558639
pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool {

0 commit comments

Comments
 (0)