Skip to content

Commit 3229987

Browse files
committed
Fix closure kind inference
Fix existing errors in test cases Fix another test Fix another test Fix lints Remove irrelevant test code included by mistake
1 parent cfd7ef0 commit 3229987

File tree

7 files changed

+106
-42
lines changed

7 files changed

+106
-42
lines changed

crates/hir-ty/src/infer/closure.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ impl InferenceContext<'_> {
924924
}
925925
}
926926

927-
fn closure_kind(&self) -> FnTrait {
927+
fn closure_kind_from_capture(&self) -> FnTrait {
928928
let mut r = FnTrait::Fn;
929929
for it in &self.current_captures {
930930
r = cmp::min(
@@ -941,7 +941,7 @@ impl InferenceContext<'_> {
941941
r
942942
}
943943

944-
fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait {
944+
fn analyze_closure(&mut self, closure: ClosureId, predicate: Option<FnTrait>) -> FnTrait {
945945
let InternedClosure(_, root) = self.db.lookup_intern_closure(closure.into());
946946
self.current_closure = Some(closure);
947947
let Expr::Closure { body, capture_by, .. } = &self.body[root] else {
@@ -959,7 +959,14 @@ impl InferenceContext<'_> {
959959
}
960960
self.restrict_precision_for_unsafe();
961961
// closure_kind should be done before adjust_for_move_closure
962-
let closure_kind = self.closure_kind();
962+
let closure_kind = {
963+
let from_capture = self.closure_kind_from_capture();
964+
// if predicate.unwrap_or(FnTrait::Fn) < from_capture {
965+
// // Diagnostics here, like compiler does in
966+
// // https://github.com/rust-lang/rust/blob/11f32b73e0dc9287e305b5b9980d24aecdc8c17f/compiler/rustc_hir_typeck/src/upvar.rs#L264
967+
// }
968+
predicate.unwrap_or(from_capture)
969+
};
963970
match capture_by {
964971
CaptureBy::Value => self.adjust_for_move_closure(),
965972
CaptureBy::Ref => (),
@@ -975,7 +982,9 @@ impl InferenceContext<'_> {
975982
let deferred_closures = self.sort_closures();
976983
for (closure, exprs) in deferred_closures.into_iter().rev() {
977984
self.current_captures = vec![];
978-
let kind = self.analyze_closure(closure);
985+
986+
let predicate = self.table.get_closure_fn_trait_predicate(closure);
987+
let kind = self.analyze_closure(closure, predicate);
979988

980989
for (derefed_callee, callee_ty, params, expr) in exprs {
981990
if let &Expr::Call { callee, .. } = &self.body[expr] {

crates/hir-ty/src/infer/unify.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
//! Unification and canonicalization logic.
22
3-
use std::{fmt, iter, mem};
3+
use std::{cmp, fmt, iter, mem};
44

55
use chalk_ir::{
66
cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy,
7-
IntTy, TyVariableKind, UniverseIndex,
7+
IntTy, TyVariableKind, UniverseIndex, WhereClause,
88
};
99
use chalk_solve::infer::ParameterEnaVariableExt;
1010
use either::Either;
@@ -14,11 +14,12 @@ use triomphe::Arc;
1414

1515
use super::{InferOk, InferResult, InferenceContext, TypeError};
1616
use crate::{
17-
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
18-
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
19-
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
20-
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
21-
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
17+
chalk_db::TraitId, consteval::unknown_const, db::HirDatabase, fold_tys_and_consts,
18+
static_lifetime, to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical,
19+
ClosureId, Const, ConstValue, DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal,
20+
GoalData, Guidance, InEnvironment, InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy,
21+
ProjectionTyExt, Scalar, Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt,
22+
TyKind, VariableKind,
2223
};
2324

2425
impl InferenceContext<'_> {
@@ -181,6 +182,8 @@ pub(crate) struct InferenceTable<'a> {
181182
/// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on
182183
/// temporary allocations.
183184
resolve_obligations_buffer: Vec<Canonicalized<InEnvironment<Goal>>>,
185+
fn_trait_predicates: Vec<(Ty, FnTrait)>,
186+
cached_fn_trait_ids: Option<CachedFnTraitIds>,
184187
}
185188

186189
pub(crate) struct InferenceTableSnapshot {
@@ -189,15 +192,34 @@ pub(crate) struct InferenceTableSnapshot {
189192
type_variable_table_snapshot: Vec<TypeVariableFlags>,
190193
}
191194

195+
#[derive(Clone)]
196+
struct CachedFnTraitIds {
197+
fn_trait: TraitId,
198+
fn_mut_trait: TraitId,
199+
fn_once_trait: TraitId,
200+
}
201+
202+
impl CachedFnTraitIds {
203+
fn new(db: &dyn HirDatabase, trait_env: &Arc<TraitEnvironment>) -> Option<Self> {
204+
let fn_trait = FnTrait::Fn.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
205+
let fn_mut_trait = FnTrait::FnMut.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
206+
let fn_once_trait = FnTrait::FnOnce.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
207+
Some(Self { fn_trait, fn_mut_trait, fn_once_trait })
208+
}
209+
}
210+
192211
impl<'a> InferenceTable<'a> {
193212
pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc<TraitEnvironment>) -> Self {
213+
let cached_fn_trait_ids = CachedFnTraitIds::new(db, &trait_env);
194214
InferenceTable {
195215
db,
196216
trait_env,
197217
var_unification_table: ChalkInferenceTable::new(),
198218
type_variable_table: Vec::new(),
199219
pending_obligations: Vec::new(),
200220
resolve_obligations_buffer: Vec::new(),
221+
fn_trait_predicates: Vec::new(),
222+
cached_fn_trait_ids,
201223
}
202224
}
203225

@@ -547,6 +569,22 @@ impl<'a> InferenceTable<'a> {
547569
}
548570

549571
fn register_obligation_in_env(&mut self, goal: InEnvironment<Goal>) {
572+
if let Some(fn_trait_ids) = &self.cached_fn_trait_ids {
573+
if let GoalData::DomainGoal(DomainGoal::Holds(WhereClause::Implemented(trait_ref))) =
574+
goal.goal.data(Interner)
575+
{
576+
if let Some(ty) = trait_ref.substitution.type_parameters(Interner).next() {
577+
if trait_ref.trait_id == fn_trait_ids.fn_trait {
578+
self.fn_trait_predicates.push((ty, FnTrait::Fn));
579+
} else if trait_ref.trait_id == fn_trait_ids.fn_mut_trait {
580+
self.fn_trait_predicates.push((ty, FnTrait::FnMut));
581+
} else if trait_ref.trait_id == fn_trait_ids.fn_once_trait {
582+
self.fn_trait_predicates.push((ty, FnTrait::FnOnce));
583+
}
584+
}
585+
}
586+
}
587+
550588
let canonicalized = self.canonicalize(goal);
551589
let solution = self.try_resolve_obligation(&canonicalized);
552590
if matches!(solution, Some(Solution::Ambig(_))) {
@@ -838,6 +876,23 @@ impl<'a> InferenceTable<'a> {
838876
_ => c,
839877
}
840878
}
879+
880+
pub(super) fn get_closure_fn_trait_predicate(
881+
&mut self,
882+
closure_id: ClosureId,
883+
) -> Option<FnTrait> {
884+
let predicates = mem::take(&mut self.fn_trait_predicates);
885+
let res = predicates.iter().filter_map(|(ty, fn_trait)| {
886+
if matches!(self.resolve_completely(ty.clone()).kind(Interner), TyKind::Closure(c, ..) if *c == closure_id) {
887+
Some(*fn_trait)
888+
} else {
889+
None
890+
}
891+
}).fold(None, |acc, x| Some(cmp::max(acc.unwrap_or(FnTrait::FnOnce), x)));
892+
self.fn_trait_predicates = predicates;
893+
894+
res
895+
}
841896
}
842897

843898
impl fmt::Debug for InferenceTable<'_> {

crates/hir-ty/src/tests/patterns.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,25 +702,25 @@ fn test() {
702702
51..58 'loop {}': !
703703
56..58 '{}': ()
704704
72..171 '{ ... x); }': ()
705-
78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32
705+
78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32
706706
78..105 'foo(&(...y)| x)': i32
707707
82..91 '&(1, "a")': &(i32, &str)
708708
83..91 '(1, "a")': (i32, &str)
709709
84..85 '1': i32
710710
87..90 '"a"': &str
711-
93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32
711+
93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32
712712
94..101 '&(x, y)': &(i32, &str)
713713
95..101 '(x, y)': (i32, &str)
714714
96..97 'x': i32
715715
99..100 'y': &str
716716
103..104 'x': i32
717-
142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32
717+
142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32
718718
142..168 'foo(&(...y)| x)': &i32
719719
146..155 '&(1, "a")': &(i32, &str)
720720
147..155 '(1, "a")': (i32, &str)
721721
148..149 '1': i32
722722
151..154 '"a"': &str
723-
157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32
723+
157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32
724724
158..164 '(x, y)': (i32, &str)
725725
159..160 'x': &i32
726726
162..163 'y': &&str

crates/hir-ty/src/tests/regression.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ fn main() {
862862
123..126 'S()': S<i32>
863863
132..133 's': S<i32>
864864
132..144 's.g(|_x| {})': ()
865-
136..143 '|_x| {}': impl Fn(&i32)
865+
136..143 '|_x| {}': impl FnOnce(&i32)
866866
137..139 '_x': &i32
867867
141..143 '{}': ()
868868
150..151 's': S<i32>

crates/hir-ty/src/tests/simple.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,9 +2190,9 @@ fn main() {
21902190
149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()>
21912191
149..155 'Ok(())': Result<(), ()>
21922192
152..154 '()': ()
2193-
167..171 'test': fn test<(), (), impl Fn() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl Fn() -> impl Future<Output = Result<(), ()>>)
2193+
167..171 'test': fn test<(), (), impl FnMut() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl FnMut() -> impl Future<Output = Result<(), ()>>)
21942194
167..228 'test(|... })': ()
2195-
172..227 '|| asy... }': impl Fn() -> impl Future<Output = Result<(), ()>>
2195+
172..227 '|| asy... }': impl FnMut() -> impl Future<Output = Result<(), ()>>
21962196
175..227 'async ... }': impl Future<Output = Result<(), ()>>
21972197
191..205 'return Err(())': !
21982198
198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()>

crates/hir-ty/src/tests/traits.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,9 +1333,9 @@ fn foo<const C: u8, T>() -> (impl FnOnce(&str, T), impl Trait<u8>) {
13331333
}
13341334
"#,
13351335
expect![[r#"
1336-
134..165 '{ ...(C)) }': (impl Fn(&str, T), Bar<u8>)
1337-
140..163 '(|inpu...ar(C))': (impl Fn(&str, T), Bar<u8>)
1338-
141..154 '|input, t| {}': impl Fn(&str, T)
1336+
134..165 '{ ...(C)) }': (impl FnOnce(&str, T), Bar<u8>)
1337+
140..163 '(|inpu...ar(C))': (impl FnOnce(&str, T), Bar<u8>)
1338+
141..154 '|input, t| {}': impl FnOnce(&str, T)
13391339
142..147 'input': &str
13401340
149..150 't': T
13411341
152..154 '{}': ()
@@ -1963,20 +1963,20 @@ fn test() {
19631963
163..167 '1u32': u32
19641964
174..175 'x': Option<u32>
19651965
174..190 'x.map(...v + 1)': Option<u32>
1966-
180..189 '|v| v + 1': impl Fn(u32) -> u32
1966+
180..189 '|v| v + 1': impl FnOnce(u32) -> u32
19671967
181..182 'v': u32
19681968
184..185 'v': u32
19691969
184..189 'v + 1': u32
19701970
188..189 '1': u32
19711971
196..197 'x': Option<u32>
19721972
196..212 'x.map(... 1u64)': Option<u64>
1973-
202..211 '|_v| 1u64': impl Fn(u32) -> u64
1973+
202..211 '|_v| 1u64': impl FnOnce(u32) -> u64
19741974
203..205 '_v': u32
19751975
207..211 '1u64': u64
19761976
222..223 'y': Option<i64>
19771977
239..240 'x': Option<u32>
19781978
239..252 'x.map(|_v| 1)': Option<i64>
1979-
245..251 '|_v| 1': impl Fn(u32) -> i64
1979+
245..251 '|_v| 1': impl FnOnce(u32) -> i64
19801980
246..248 '_v': u32
19811981
250..251 '1': i64
19821982
"#]],
@@ -2062,17 +2062,17 @@ fn test() {
20622062
312..314 '{}': ()
20632063
330..489 '{ ... S); }': ()
20642064
340..342 'x1': u64
2065-
345..349 'foo1': fn foo1<S, u64, impl Fn(S) -> u64>(S, impl Fn(S) -> u64) -> u64
2065+
345..349 'foo1': fn foo1<S, u64, impl FnOnce(S) -> u64>(S, impl FnOnce(S) -> u64) -> u64
20662066
345..368 'foo1(S...hod())': u64
20672067
350..351 'S': S
2068-
353..367 '|s| s.method()': impl Fn(S) -> u64
2068+
353..367 '|s| s.method()': impl FnOnce(S) -> u64
20692069
354..355 's': S
20702070
357..358 's': S
20712071
357..367 's.method()': u64
20722072
378..380 'x2': u64
2073-
383..387 'foo2': fn foo2<S, u64, impl Fn(S) -> u64>(impl Fn(S) -> u64, S) -> u64
2073+
383..387 'foo2': fn foo2<S, u64, impl FnOnce(S) -> u64>(impl FnOnce(S) -> u64, S) -> u64
20742074
383..406 'foo2(|...(), S)': u64
2075-
388..402 '|s| s.method()': impl Fn(S) -> u64
2075+
388..402 '|s| s.method()': impl FnOnce(S) -> u64
20762076
389..390 's': S
20772077
392..393 's': S
20782078
392..402 's.method()': u64
@@ -2081,14 +2081,14 @@ fn test() {
20812081
421..422 'S': S
20822082
421..446 'S.foo1...hod())': u64
20832083
428..429 'S': S
2084-
431..445 '|s| s.method()': impl Fn(S) -> u64
2084+
431..445 '|s| s.method()': impl FnOnce(S) -> u64
20852085
432..433 's': S
20862086
435..436 's': S
20872087
435..445 's.method()': u64
20882088
456..458 'x4': u64
20892089
461..462 'S': S
20902090
461..486 'S.foo2...(), S)': u64
2091-
468..482 '|s| s.method()': impl Fn(S) -> u64
2091+
468..482 '|s| s.method()': impl FnOnce(S) -> u64
20922092
469..470 's': S
20932093
472..473 's': S
20942094
472..482 's.method()': u64
@@ -2562,9 +2562,9 @@ fn main() {
25622562
72..74 '_v': F
25632563
117..120 '{ }': ()
25642564
132..163 '{ ... }); }': ()
2565-
138..148 'f::<(), _>': fn f<(), impl Fn(&())>(impl Fn(&()))
2565+
138..148 'f::<(), _>': fn f<(), impl FnOnce(&())>(impl FnOnce(&()))
25662566
138..160 'f::<()... z; })': ()
2567-
149..159 '|z| { z; }': impl Fn(&())
2567+
149..159 '|z| { z; }': impl FnOnce(&())
25682568
150..151 'z': &()
25692569
153..159 '{ z; }': ()
25702570
155..156 'z': &()
@@ -2749,9 +2749,9 @@ fn main() {
27492749
983..998 'Vec::<i32>::new': fn new<i32>() -> Vec<i32>
27502750
983..1000 'Vec::<...:new()': Vec<i32>
27512751
983..1012 'Vec::<...iter()': IntoIter<i32>
2752-
983..1075 'Vec::<...one })': FilterMap<IntoIter<i32>, impl Fn(i32) -> Option<u32>>
2752+
983..1075 'Vec::<...one })': FilterMap<IntoIter<i32>, impl FnMut(i32) -> Option<u32>>
27532753
983..1101 'Vec::<... y; })': ()
2754-
1029..1074 '|x| if...None }': impl Fn(i32) -> Option<u32>
2754+
1029..1074 '|x| if...None }': impl FnMut(i32) -> Option<u32>
27552755
1030..1031 'x': i32
27562756
1033..1074 'if x >...None }': Option<u32>
27572757
1036..1037 'x': i32
@@ -2764,7 +2764,7 @@ fn main() {
27642764
1049..1057 'x as u32': u32
27652765
1066..1074 '{ None }': Option<u32>
27662766
1068..1072 'None': Option<u32>
2767-
1090..1100 '|y| { y; }': impl Fn(u32)
2767+
1090..1100 '|y| { y; }': impl FnMut(u32)
27682768
1091..1092 'y': u32
27692769
1094..1100 '{ y; }': ()
27702770
1096..1097 'y': u32

crates/ide/src/hover/tests.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ fn main() {
353353
expect![[r#"
354354
```rust
355355
{closure#0} // size = 8, align = 8, niches = 1
356-
impl FnOnce() -> S2
356+
impl Fn() -> S2
357357
```
358-
Coerced to: &impl FnOnce() -> S2
358+
Coerced to: &impl Fn() -> S2
359359
360360
## Captures
361361
* `x` by move"#]],
@@ -401,17 +401,17 @@ fn main() {
401401
},
402402
},
403403
HoverGotoTypeData {
404-
mod_path: "core::ops::function::FnOnce",
404+
mod_path: "core::ops::function::Fn",
405405
nav: NavigationTarget {
406406
file_id: FileId(
407407
1,
408408
),
409-
full_range: 632..867,
410-
focus_range: 693..699,
411-
name: "FnOnce",
409+
full_range: 254..425,
410+
focus_range: 310..312,
411+
name: "Fn",
412412
kind: Trait,
413413
container_name: "function",
414-
description: "pub trait FnOnce<Args>\nwhere\n Args: Tuple,",
414+
description: "pub trait Fn<Args>\nwhere\n Self: FnMut<Args>,\n Args: Tuple,",
415415
},
416416
},
417417
],

0 commit comments

Comments
 (0)