Skip to content

Commit a3236be

Browse files
committed
Auto merge of rust-lang#16630 - ShoyuVanilla:fix-closure-kind-inference, r=Veykril
fix: Wrong closure kind deduction for closures with predicates Completes rust-lang#16472, fixes rust-lang#16421 The changed closure kind deduction is mostly simlar to `rustc_hir_typeck/src/closure.rs`. Porting closure sig deduction from it seems possible too and I'm considering doing it with another PR
2 parents c031246 + a4021f6 commit a3236be

File tree

8 files changed

+269
-36
lines changed

8 files changed

+269
-36
lines changed

crates/hir-ty/src/infer/closure.rs

+81-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
55
use chalk_ir::{
66
cast::Cast,
77
fold::{FallibleTypeFolder, TypeFoldable},
8-
AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
8+
BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind,
99
};
1010
use either::Either;
1111
use hir_def::{
@@ -22,13 +22,14 @@ use stdx::never;
2222

2323
use crate::{
2424
db::{HirDatabase, InternedClosure},
25-
from_placeholder_idx, make_binders,
25+
from_chalk_trait_id, from_placeholder_idx, make_binders,
2626
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
2727
static_lifetime, to_chalk_trait_id,
2828
traits::FnTrait,
29-
utils::{self, generics, Generics},
30-
Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer,
31-
FnSig, Interner, Substitution, Ty, TyExt,
29+
utils::{self, elaborate_clause_supertraits, generics, Generics},
30+
Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy,
31+
DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty,
32+
TyExt, WhereClause,
3233
};
3334

3435
use super::{Expectation, InferenceContext};
@@ -47,6 +48,15 @@ impl InferenceContext<'_> {
4748
None => return,
4849
};
4950

51+
if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) {
52+
if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) {
53+
self.result
54+
.closure_info
55+
.entry(*closure_id)
56+
.or_insert_with(|| (Vec::new(), closure_kind));
57+
}
58+
}
59+
5060
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
5161
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
5262

@@ -65,6 +75,60 @@ impl InferenceContext<'_> {
6575
}
6676
}
6777

78+
// Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
79+
// Might need to port closure sig deductions too.
80+
fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> {
81+
match expected_ty.kind(Interner) {
82+
TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => {
83+
let clauses = expected_ty
84+
.impl_trait_bounds(self.db)
85+
.into_iter()
86+
.flatten()
87+
.map(|b| b.into_value_and_skipped_binders().0);
88+
self.deduce_closure_kind_from_predicate_clauses(clauses)
89+
}
90+
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
91+
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
92+
}),
93+
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
94+
let clauses = self.clauses_for_self_ty(*ty);
95+
self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter())
96+
}
97+
TyKind::Function(_) => Some(FnTrait::Fn),
98+
_ => None,
99+
}
100+
}
101+
102+
fn deduce_closure_kind_from_predicate_clauses(
103+
&self,
104+
clauses: impl DoubleEndedIterator<Item = WhereClause>,
105+
) -> Option<FnTrait> {
106+
let mut expected_kind = None;
107+
108+
for clause in elaborate_clause_supertraits(self.db, clauses.rev()) {
109+
let trait_id = match clause {
110+
WhereClause::AliasEq(AliasEq {
111+
alias: AliasTy::Projection(projection), ..
112+
}) => Some(projection.trait_(self.db)),
113+
WhereClause::Implemented(trait_ref) => {
114+
Some(from_chalk_trait_id(trait_ref.trait_id))
115+
}
116+
_ => None,
117+
};
118+
if let Some(closure_kind) =
119+
trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id))
120+
{
121+
// `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
122+
expected_kind = Some(
123+
expected_kind
124+
.map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)),
125+
);
126+
}
127+
}
128+
129+
expected_kind
130+
}
131+
68132
fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
69133
// Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
70134

@@ -111,6 +175,10 @@ impl InferenceContext<'_> {
111175

112176
None
113177
}
178+
179+
fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
180+
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
181+
}
114182
}
115183

116184
// The below functions handle capture and closure kind (Fn, FnMut, ..)
@@ -962,8 +1030,14 @@ impl InferenceContext<'_> {
9621030
}
9631031
}
9641032
self.restrict_precision_for_unsafe();
965-
// closure_kind should be done before adjust_for_move_closure
966-
let closure_kind = self.closure_kind();
1033+
// `closure_kind` should be done before adjust_for_move_closure
1034+
// If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
1035+
// rustc also does diagnostics here if the latter is not a subtype of the former.
1036+
let closure_kind = self
1037+
.result
1038+
.closure_info
1039+
.get(&closure)
1040+
.map_or_else(|| self.closure_kind(), |info| info.1);
9671041
match capture_by {
9681042
CaptureBy::Value => self.adjust_for_move_closure(),
9691043
CaptureBy::Ref => (),

crates/hir-ty/src/infer/unify.rs

+70-3
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt;
1010
use either::Either;
1111
use ena::unify::UnifyKey;
1212
use hir_expand::name;
13+
use smallvec::SmallVec;
1314
use triomphe::Arc;
1415

1516
use super::{InferOk, InferResult, InferenceContext, TypeError};
1617
use crate::{
1718
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
1819
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
19-
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
20-
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
21-
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
20+
DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment,
21+
InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution,
22+
Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause,
2223
};
2324

2425
impl InferenceContext<'_> {
@@ -31,6 +32,72 @@ impl InferenceContext<'_> {
3132
{
3233
self.table.canonicalize(t)
3334
}
35+
36+
pub(super) fn clauses_for_self_ty(
37+
&mut self,
38+
self_ty: InferenceVar,
39+
) -> SmallVec<[WhereClause; 4]> {
40+
self.table.resolve_obligations_as_possible();
41+
42+
let root = self.table.var_unification_table.inference_var_root(self_ty);
43+
let pending_obligations = mem::take(&mut self.table.pending_obligations);
44+
let obligations = pending_obligations
45+
.iter()
46+
.filter_map(|obligation| match obligation.value.value.goal.data(Interner) {
47+
GoalData::DomainGoal(DomainGoal::Holds(
48+
clause @ WhereClause::AliasEq(AliasEq {
49+
alias: AliasTy::Projection(projection),
50+
..
51+
}),
52+
)) => {
53+
let projection_self = projection.self_type_parameter(self.db);
54+
let uncanonical = chalk_ir::Substitute::apply(
55+
&obligation.free_vars,
56+
projection_self,
57+
Interner,
58+
);
59+
if matches!(
60+
self.resolve_ty_shallow(&uncanonical).kind(Interner),
61+
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
62+
) {
63+
Some(chalk_ir::Substitute::apply(
64+
&obligation.free_vars,
65+
clause.clone(),
66+
Interner,
67+
))
68+
} else {
69+
None
70+
}
71+
}
72+
GoalData::DomainGoal(DomainGoal::Holds(
73+
clause @ WhereClause::Implemented(trait_ref),
74+
)) => {
75+
let trait_ref_self = trait_ref.self_type_parameter(Interner);
76+
let uncanonical = chalk_ir::Substitute::apply(
77+
&obligation.free_vars,
78+
trait_ref_self,
79+
Interner,
80+
);
81+
if matches!(
82+
self.resolve_ty_shallow(&uncanonical).kind(Interner),
83+
TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
84+
) {
85+
Some(chalk_ir::Substitute::apply(
86+
&obligation.free_vars,
87+
clause.clone(),
88+
Interner,
89+
))
90+
} else {
91+
None
92+
}
93+
}
94+
_ => None,
95+
})
96+
.collect();
97+
self.table.pending_obligations = pending_obligations;
98+
99+
obligations
100+
}
34101
}
35102

36103
#[derive(Debug, Clone)]

crates/hir-ty/src/tests/patterns.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -702,25 +702,25 @@ fn test() {
702702
51..58 'loop {}': !
703703
56..58 '{}': ()
704704
72..171 '{ ... x); }': ()
705-
78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32
705+
78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32
706706
78..105 'foo(&(...y)| x)': i32
707707
82..91 '&(1, "a")': &(i32, &str)
708708
83..91 '(1, "a")': (i32, &str)
709709
84..85 '1': i32
710710
87..90 '"a"': &str
711-
93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32
711+
93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32
712712
94..101 '&(x, y)': &(i32, &str)
713713
95..101 '(x, y)': (i32, &str)
714714
96..97 'x': i32
715715
99..100 'y': &str
716716
103..104 'x': i32
717-
142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32
717+
142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32
718718
142..168 'foo(&(...y)| x)': &i32
719719
146..155 '&(1, "a")': &(i32, &str)
720720
147..155 '(1, "a")': (i32, &str)
721721
148..149 '1': i32
722722
151..154 '"a"': &str
723-
157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32
723+
157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32
724724
158..164 '(x, y)': (i32, &str)
725725
159..160 'x': &i32
726726
162..163 'y': &&str

crates/hir-ty/src/tests/regression.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ fn main() {
862862
123..126 'S()': S<i32>
863863
132..133 's': S<i32>
864864
132..144 's.g(|_x| {})': ()
865-
136..143 '|_x| {}': impl Fn(&i32)
865+
136..143 '|_x| {}': impl FnOnce(&i32)
866866
137..139 '_x': &i32
867867
141..143 '{}': ()
868868
150..151 's': S<i32>

crates/hir-ty/src/tests/simple.rs

+39-2
Original file line numberDiff line numberDiff line change
@@ -2190,9 +2190,9 @@ fn main() {
21902190
149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()>
21912191
149..155 'Ok(())': Result<(), ()>
21922192
152..154 '()': ()
2193-
167..171 'test': fn test<(), (), impl Fn() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl Fn() -> impl Future<Output = Result<(), ()>>)
2193+
167..171 'test': fn test<(), (), impl FnMut() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl FnMut() -> impl Future<Output = Result<(), ()>>)
21942194
167..228 'test(|... })': ()
2195-
172..227 '|| asy... }': impl Fn() -> impl Future<Output = Result<(), ()>>
2195+
172..227 '|| asy... }': impl FnMut() -> impl Future<Output = Result<(), ()>>
21962196
175..227 'async ... }': impl Future<Output = Result<(), ()>>
21972197
191..205 'return Err(())': !
21982198
198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()>
@@ -2886,6 +2886,43 @@ fn f() {
28862886
)
28872887
}
28882888

2889+
#[test]
2890+
fn closure_kind_with_predicates() {
2891+
check_types(
2892+
r#"
2893+
//- minicore: fn
2894+
#![feature(unboxed_closures)]
2895+
2896+
struct X<T: FnOnce()>(T);
2897+
2898+
fn f1() -> impl FnOnce() {
2899+
|| {}
2900+
// ^^^^^ impl FnOnce()
2901+
}
2902+
2903+
fn f2(c: impl FnOnce<(), Output = i32>) {}
2904+
2905+
fn test {
2906+
let x1 = X(|| {});
2907+
let c1 = x1.0;
2908+
// ^^ impl FnOnce()
2909+
2910+
let c2 = || {};
2911+
// ^^ impl Fn()
2912+
let x2 = X(c2);
2913+
let c3 = x2.0
2914+
// ^^ impl Fn()
2915+
2916+
let c4 = f1();
2917+
// ^^ impl FnOnce() + ?Sized
2918+
2919+
f2(|| { 0 });
2920+
// ^^^^^^^^ impl FnOnce() -> i32
2921+
}
2922+
"#,
2923+
)
2924+
}
2925+
28892926
#[test]
28902927
fn derive_macro_should_work_for_associated_type() {
28912928
check_types(

0 commit comments

Comments
 (0)