Skip to content

Commit d74d67c

Browse files
Implement async closure signature deduction
1 parent b0696a5 commit d74d67c

File tree

2 files changed

+57
-27
lines changed

2 files changed

+57
-27
lines changed

Diff for: compiler/rustc_hir_typeck/src/closure.rs

+47-27
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
5656
// It's always helpful for inference if we know the kind of
5757
// closure sooner rather than later, so first examine the expected
5858
// type, and see if can glean a closure kind from there.
59-
let (expected_sig, expected_kind) = match closure.kind {
60-
hir::ClosureKind::Closure => match expected.to_option(self) {
61-
Some(ty) => {
62-
self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
63-
}
64-
None => (None, None),
65-
},
66-
// We don't want to deduce a signature from `Fn` bounds for coroutines
67-
// or coroutine-closures, because the former does not implement `Fn`
68-
// ever, and the latter's signature doesn't correspond to the coroutine
69-
// type that it returns.
70-
hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None),
59+
let (expected_sig, expected_kind) = match expected.to_option(self) {
60+
Some(ty) => self.deduce_closure_signature(
61+
self.try_structurally_resolve_type(expr_span, ty),
62+
closure.kind,
63+
),
64+
None => (None, None),
7165
};
7266

7367
let ClosureSignatures { bound_sig, mut liberated_sig } =
@@ -323,11 +317,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
323317
fn deduce_closure_signature(
324318
&self,
325319
expected_ty: Ty<'tcx>,
320+
closure_kind: hir::ClosureKind,
326321
) -> (Option<ExpectedSig<'tcx>>, Option<ty::ClosureKind>) {
327322
match *expected_ty.kind() {
328323
ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) => self
329324
.deduce_closure_signature_from_predicates(
330325
expected_ty,
326+
closure_kind,
331327
self.tcx
332328
.explicit_item_bounds(def_id)
333329
.iter_instantiated_copied(self.tcx, args)
@@ -336,7 +332,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
336332
ty::Dynamic(object_type, ..) => {
337333
let sig = object_type.projection_bounds().find_map(|pb| {
338334
let pb = pb.with_self_ty(self.tcx, self.tcx.types.trait_object_dummy_self);
339-
self.deduce_sig_from_projection(None, pb)
335+
self.deduce_sig_from_projection(None, closure_kind, pb)
340336
});
341337
let kind = object_type
342338
.principal_def_id()
@@ -345,19 +341,26 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
345341
}
346342
ty::Infer(ty::TyVar(vid)) => self.deduce_closure_signature_from_predicates(
347343
Ty::new_var(self.tcx, self.root_var(vid)),
344+
closure_kind,
348345
self.obligations_for_self_ty(vid).map(|obl| (obl.predicate, obl.cause.span)),
349346
),
350-
ty::FnPtr(sig) => {
351-
let expected_sig = ExpectedSig { cause_span: None, sig };
352-
(Some(expected_sig), Some(ty::ClosureKind::Fn))
353-
}
347+
ty::FnPtr(sig) => match closure_kind {
348+
hir::ClosureKind::Closure => {
349+
let expected_sig = ExpectedSig { cause_span: None, sig };
350+
(Some(expected_sig), Some(ty::ClosureKind::Fn))
351+
}
352+
hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => {
353+
(None, None)
354+
}
355+
},
354356
_ => (None, None),
355357
}
356358
}
357359

358360
fn deduce_closure_signature_from_predicates(
359361
&self,
360362
expected_ty: Ty<'tcx>,
363+
closure_kind: hir::ClosureKind,
361364
predicates: impl DoubleEndedIterator<Item = (ty::Predicate<'tcx>, Span)>,
362365
) -> (Option<ExpectedSig<'tcx>>, Option<ty::ClosureKind>) {
363366
let mut expected_sig = None;
@@ -386,6 +389,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
386389
span,
387390
self.deduce_sig_from_projection(
388391
Some(span),
392+
closure_kind,
389393
bound_predicate.rebind(proj_predicate),
390394
),
391395
);
@@ -422,13 +426,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
422426
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data)) => Some(data.def_id()),
423427
_ => None,
424428
};
425-
if let Some(closure_kind) =
426-
trait_def_id.and_then(|def_id| self.tcx.fn_trait_kind_from_def_id(def_id))
427-
{
428-
expected_kind = Some(
429-
expected_kind
430-
.map_or_else(|| closure_kind, |current| cmp::min(current, closure_kind)),
431-
);
429+
430+
if let Some(trait_def_id) = trait_def_id {
431+
let found_kind = match closure_kind {
432+
hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id),
433+
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => {
434+
self.tcx.async_fn_trait_kind_from_def_id(trait_def_id)
435+
}
436+
_ => None,
437+
};
438+
439+
if let Some(found_kind) = found_kind {
440+
expected_kind = Some(
441+
expected_kind
442+
.map_or_else(|| found_kind, |current| cmp::min(current, found_kind)),
443+
);
444+
}
432445
}
433446
}
434447

@@ -445,14 +458,21 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
445458
fn deduce_sig_from_projection(
446459
&self,
447460
cause_span: Option<Span>,
461+
closure_kind: hir::ClosureKind,
448462
projection: ty::PolyProjectionPredicate<'tcx>,
449463
) -> Option<ExpectedSig<'tcx>> {
450464
let tcx = self.tcx;
451465

452466
let trait_def_id = projection.trait_def_id(tcx);
453-
// For now, we only do signature deduction based off of the `Fn` traits.
454-
if !tcx.is_fn_trait(trait_def_id) {
455-
return None;
467+
468+
// For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
469+
// for closures and async closures, respectively.
470+
match closure_kind {
471+
hir::ClosureKind::Closure
472+
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
473+
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
474+
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
475+
_ => return None,
456476
}
457477

458478
let arg_param_ty = projection.skip_binder().projection_ty.args.type_at(1);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//@ check-pass
2+
//@ edition: 2021
3+
4+
#![feature(async_closure)]
5+
6+
async fn foo(x: impl async Fn(&str) -> &str) {}
7+
8+
fn main() {
9+
foo(async |x| x);
10+
}

0 commit comments

Comments
 (0)