Skip to content

Commit a3c2218

Browse files
authored
Improve inference for context sensitive functions in object and array literal arguments (#48538)
* Use intra-expression inference sites in type argument inference * Accept new baselines * Add tests
1 parent 7da80d7 commit a3c2218

File tree

8 files changed

+1968
-43
lines changed

8 files changed

+1968
-43
lines changed

src/compiler/checker.ts

+87-31
Original file line numberDiff line numberDiff line change
@@ -21749,6 +21749,9 @@ namespace ts {
2174921749
const inference = inferences[i];
2175021750
if (t === inference.typeParameter) {
2175121751
if (fix && !inference.isFixed) {
21752+
// Before we commit to a particular inference (and thus lock out any further inferences),
21753+
// we infer from any intra-expression inference sites we have collected.
21754+
inferFromIntraExpressionSites(context);
2175221755
clearCachedInferences(inferences);
2175321756
inference.isFixed = true;
2175421757
}
@@ -21766,6 +21769,37 @@ namespace ts {
2176621769
}
2176721770
}
2176821771

21772+
function addIntraExpressionInferenceSite(context: InferenceContext, node: Expression | MethodDeclaration, type: Type) {
21773+
(context.intraExpressionInferenceSites ??= []).push({ node, type });
21774+
}
21775+
21776+
// We collect intra-expression inference sites within object and array literals to handle cases where
21777+
// inferred types flow between context sensitive element expressions. For example:
21778+
//
21779+
// declare function foo<T>(arg: [(n: number) => T, (x: T) => void]): void;
21780+
// foo([_a => 0, n => n.toFixed()]);
21781+
//
21782+
// Above, both arrow functions in the tuple argument are context sensitive, thus both are omitted from the
21783+
// pass that collects inferences from the non-context sensitive parts of the arguments. In the subsequent
21784+
// pass where nothing is omitted, we need to commit to an inference for T in order to contextually type the
21785+
// parameter in the second arrow function, but we want to first infer from the return type of the first
21786+
// arrow function. This happens automatically when the arrow functions are discrete arguments (because we
21787+
// infer from each argument before processing the next), but when the arrow functions are elements of an
21788+
// object or array literal, we need to perform intra-expression inferences early.
21789+
function inferFromIntraExpressionSites(context: InferenceContext) {
21790+
if (context.intraExpressionInferenceSites) {
21791+
for (const { node, type } of context.intraExpressionInferenceSites) {
21792+
const contextualType = node.kind === SyntaxKind.MethodDeclaration ?
21793+
getContextualTypeForObjectLiteralMethod(node as MethodDeclaration, ContextFlags.NoConstraints) :
21794+
getContextualType(node, ContextFlags.NoConstraints);
21795+
if (contextualType) {
21796+
inferTypes(context.inferences, type, contextualType);
21797+
}
21798+
}
21799+
context.intraExpressionInferenceSites = undefined;
21800+
}
21801+
}
21802+
2176921803
function createInferenceInfo(typeParameter: TypeParameter): InferenceInfo {
2177021804
return {
2177121805
typeParameter,
@@ -27429,6 +27463,11 @@ namespace ts {
2742927463
const type = checkExpressionForMutableLocation(e, checkMode, elementContextualType, forceTuple);
2743027464
elementTypes.push(addOptionality(type, /*isProperty*/ true, hasOmittedExpression));
2743127465
elementFlags.push(hasOmittedExpression ? ElementFlags.Optional : ElementFlags.Required);
27466+
if (contextualType && someType(contextualType, isTupleLikeType) && checkMode && checkMode & CheckMode.Inferential && !(checkMode & CheckMode.SkipContextSensitive) && isContextSensitive(e)) {
27467+
const inferenceContext = getInferenceContext(node);
27468+
Debug.assert(inferenceContext); // In CheckMode.Inferential we should always have an inference context
27469+
addIntraExpressionInferenceSite(inferenceContext, e, type);
27470+
}
2743227471
}
2743327472
}
2743427473
if (inDestructuringPattern) {
@@ -27646,6 +27685,14 @@ namespace ts {
2764627685
prop.target = member;
2764727686
member = prop;
2764827687
allPropertiesTable?.set(prop.escapedName, prop);
27688+
27689+
if (contextualType && checkMode && checkMode & CheckMode.Inferential && !(checkMode & CheckMode.SkipContextSensitive) &&
27690+
(memberDecl.kind === SyntaxKind.PropertyAssignment || memberDecl.kind === SyntaxKind.MethodDeclaration) && isContextSensitive(memberDecl)) {
27691+
const inferenceContext = getInferenceContext(node);
27692+
Debug.assert(inferenceContext); // In CheckMode.Inferential we should always have an inference context
27693+
const inferenceNode = memberDecl.kind === SyntaxKind.PropertyAssignment ? memberDecl.initializer : memberDecl;
27694+
addIntraExpressionInferenceSite(inferenceContext, inferenceNode, type);
27695+
}
2764927696
}
2765027697
else if (memberDecl.kind === SyntaxKind.SpreadAssignment) {
2765127698
if (languageVersion < ScriptTarget.ES2015) {
@@ -29748,34 +29795,36 @@ namespace ts {
2974829795
if (node.kind !== SyntaxKind.Decorator) {
2974929796
const contextualType = getContextualType(node, every(signature.typeParameters, p => !!getDefaultFromTypeParameter(p)) ? ContextFlags.SkipBindingPatterns : ContextFlags.None);
2975029797
if (contextualType) {
29751-
// We clone the inference context to avoid disturbing a resolution in progress for an
29752-
// outer call expression. Effectively we just want a snapshot of whatever has been
29753-
// inferred for any outer call expression so far.
29754-
const outerContext = getInferenceContext(node);
29755-
const outerMapper = getMapperFromContext(cloneInferenceContext(outerContext, InferenceFlags.NoDefault));
29756-
const instantiatedType = instantiateType(contextualType, outerMapper);
29757-
// If the contextual type is a generic function type with a single call signature, we
29758-
// instantiate the type with its own type parameters and type arguments. This ensures that
29759-
// the type parameters are not erased to type any during type inference such that they can
29760-
// be inferred as actual types from the contextual type. For example:
29761-
// declare function arrayMap<T, U>(f: (x: T) => U): (a: T[]) => U[];
29762-
// const boxElements: <A>(a: A[]) => { value: A }[] = arrayMap(value => ({ value }));
29763-
// Above, the type of the 'value' parameter is inferred to be 'A'.
29764-
const contextualSignature = getSingleCallSignature(instantiatedType);
29765-
const inferenceSourceType = contextualSignature && contextualSignature.typeParameters ?
29766-
getOrCreateTypeFromSignature(getSignatureInstantiationWithoutFillingInTypeArguments(contextualSignature, contextualSignature.typeParameters)) :
29767-
instantiatedType;
2976829798
const inferenceTargetType = getReturnTypeOfSignature(signature);
29769-
// Inferences made from return types have lower priority than all other inferences.
29770-
inferTypes(context.inferences, inferenceSourceType, inferenceTargetType, InferencePriority.ReturnType);
29771-
// Create a type mapper for instantiating generic contextual types using the inferences made
29772-
// from the return type. We need a separate inference pass here because (a) instantiation of
29773-
// the source type uses the outer context's return mapper (which excludes inferences made from
29774-
// outer arguments), and (b) we don't want any further inferences going into this context.
29775-
const returnContext = createInferenceContext(signature.typeParameters!, signature, context.flags);
29776-
const returnSourceType = instantiateType(contextualType, outerContext && outerContext.returnMapper);
29777-
inferTypes(returnContext.inferences, returnSourceType, inferenceTargetType);
29778-
context.returnMapper = some(returnContext.inferences, hasInferenceCandidates) ? getMapperFromContext(cloneInferredPartOfContext(returnContext)) : undefined;
29799+
if (couldContainTypeVariables(inferenceTargetType)) {
29800+
// We clone the inference context to avoid disturbing a resolution in progress for an
29801+
// outer call expression. Effectively we just want a snapshot of whatever has been
29802+
// inferred for any outer call expression so far.
29803+
const outerContext = getInferenceContext(node);
29804+
const outerMapper = getMapperFromContext(cloneInferenceContext(outerContext, InferenceFlags.NoDefault));
29805+
const instantiatedType = instantiateType(contextualType, outerMapper);
29806+
// If the contextual type is a generic function type with a single call signature, we
29807+
// instantiate the type with its own type parameters and type arguments. This ensures that
29808+
// the type parameters are not erased to type any during type inference such that they can
29809+
// be inferred as actual types from the contextual type. For example:
29810+
// declare function arrayMap<T, U>(f: (x: T) => U): (a: T[]) => U[];
29811+
// const boxElements: <A>(a: A[]) => { value: A }[] = arrayMap(value => ({ value }));
29812+
// Above, the type of the 'value' parameter is inferred to be 'A'.
29813+
const contextualSignature = getSingleCallSignature(instantiatedType);
29814+
const inferenceSourceType = contextualSignature && contextualSignature.typeParameters ?
29815+
getOrCreateTypeFromSignature(getSignatureInstantiationWithoutFillingInTypeArguments(contextualSignature, contextualSignature.typeParameters)) :
29816+
instantiatedType;
29817+
// Inferences made from return types have lower priority than all other inferences.
29818+
inferTypes(context.inferences, inferenceSourceType, inferenceTargetType, InferencePriority.ReturnType);
29819+
// Create a type mapper for instantiating generic contextual types using the inferences made
29820+
// from the return type. We need a separate inference pass here because (a) instantiation of
29821+
// the source type uses the outer context's return mapper (which excludes inferences made from
29822+
// outer arguments), and (b) we don't want any further inferences going into this context.
29823+
const returnContext = createInferenceContext(signature.typeParameters!, signature, context.flags);
29824+
const returnSourceType = instantiateType(contextualType, outerContext && outerContext.returnMapper);
29825+
inferTypes(returnContext.inferences, returnSourceType, inferenceTargetType);
29826+
context.returnMapper = some(returnContext.inferences, hasInferenceCandidates) ? getMapperFromContext(cloneInferredPartOfContext(returnContext)) : undefined;
29827+
}
2977929828
}
2978029829
}
2978129830

@@ -29789,7 +29838,7 @@ namespace ts {
2978929838
}
2979029839

2979129840
const thisType = getThisTypeOfSignature(signature);
29792-
if (thisType) {
29841+
if (thisType && couldContainTypeVariables(thisType)) {
2979329842
const thisArgumentNode = getThisArgumentOfCall(node);
2979429843
inferTypes(context.inferences, getThisArgumentType(thisArgumentNode), thisType);
2979529844
}
@@ -29798,12 +29847,14 @@ namespace ts {
2979829847
const arg = args[i];
2979929848
if (arg.kind !== SyntaxKind.OmittedExpression && !(checkMode & CheckMode.IsForStringLiteralArgumentCompletions && hasSkipDirectInferenceFlag(arg))) {
2980029849
const paramType = getTypeAtPosition(signature, i);
29801-
const argType = checkExpressionWithContextualType(arg, paramType, context, checkMode);
29802-
inferTypes(context.inferences, argType, paramType);
29850+
if (couldContainTypeVariables(paramType)) {
29851+
const argType = checkExpressionWithContextualType(arg, paramType, context, checkMode);
29852+
inferTypes(context.inferences, argType, paramType);
29853+
}
2980329854
}
2980429855
}
2980529856

29806-
if (restType) {
29857+
if (restType && couldContainTypeVariables(restType)) {
2980729858
const spreadType = getSpreadArgumentType(args, argCount, args.length, restType, context, checkMode);
2980829859
inferTypes(context.inferences, spreadType, restType);
2980929860
}
@@ -34162,6 +34213,11 @@ namespace ts {
3416234213
context.contextualType = contextualType;
3416334214
context.inferenceContext = inferenceContext;
3416434215
const type = checkExpression(node, checkMode | CheckMode.Contextual | (inferenceContext ? CheckMode.Inferential : 0));
34216+
// In CheckMode.Inferential we collect intra-expression inference sites to process before fixing any type
34217+
// parameters. This information is no longer needed after the call to checkExpression.
34218+
if (inferenceContext && inferenceContext.intraExpressionInferenceSites) {
34219+
inferenceContext.intraExpressionInferenceSites = undefined;
34220+
}
3416534221
// We strip literal freshness when an appropriate contextual type is present such that contextually typed
3416634222
// literals always preserve their literal types (otherwise they might widen during type inference). An alternative
3416734223
// here would be to not mark contextually typed literals as fresh in the first place.

src/compiler/types.ts

+7
Original file line numberDiff line numberDiff line change
@@ -5915,6 +5915,13 @@ namespace ts {
59155915
nonFixingMapper: TypeMapper; // Mapper that doesn't fix inferences
59165916
returnMapper?: TypeMapper; // Type mapper for inferences from return types (if any)
59175917
inferredTypeParameters?: readonly TypeParameter[]; // Inferred type parameters for function result
5918+
intraExpressionInferenceSites?: IntraExpressionInferenceSite[];
5919+
}
5920+
5921+
/* @internal */
5922+
export interface IntraExpressionInferenceSite {
5923+
node: Expression | MethodDeclaration;
5924+
type: Type;
59185925
}
59195926

59205927
/* @internal */

tests/baselines/reference/contextualTypingOfOptionalMembers.types

+12-12
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ declare function app<State, Actions extends ActionsObject<State>>(obj: Options<S
2525
app({
2626
>app({ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,}) : void
2727
>app : <State, Actions extends ActionsObject<State>>(obj: Options<State, Actions>) => void
28-
>{ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,} : { state: number; actions: { foo: (s: number) => number; }; view: (s: number, a: ActionsObject<number>) => any; }
28+
>{ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,} : { state: number; actions: { foo: (s: number) => number; }; view: (s: number, a: { foo: (s: number) => number; }) => any; }
2929

3030
state: 100,
3131
>state : number
@@ -43,10 +43,10 @@ app({
4343

4444
},
4545
view: (s, a) => undefined as any,
46-
>view : (s: number, a: ActionsObject<number>) => any
47-
>(s, a) => undefined as any : (s: number, a: ActionsObject<number>) => any
46+
>view : (s: number, a: { foo: (s: number) => number; }) => any
47+
>(s, a) => undefined as any : (s: number, a: { foo: (s: number) => number; }) => any
4848
>s : number
49-
>a : ActionsObject<number>
49+
>a : { foo: (s: number) => number; }
5050
>undefined as any : any
5151
>undefined : undefined
5252

@@ -95,7 +95,7 @@ declare function app2<State, Actions extends ActionsObject<State>>(obj: Options2
9595
app2({
9696
>app2({ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,}) : void
9797
>app2 : <State, Actions extends ActionsObject<State>>(obj: Options2<State, Actions>) => void
98-
>{ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,} : { state: number; actions: { foo: (s: number) => number; }; view: (s: number, a: ActionsObject<number>) => any; }
98+
>{ state: 100, actions: { foo: s => s // Should be typed number => number }, view: (s, a) => undefined as any,} : { state: number; actions: { foo: (s: number) => number; }; view: (s: number, a: { foo: (s: number) => number; }) => any; }
9999

100100
state: 100,
101101
>state : number
@@ -113,10 +113,10 @@ app2({
113113

114114
},
115115
view: (s, a) => undefined as any,
116-
>view : (s: number, a: ActionsObject<number>) => any
117-
>(s, a) => undefined as any : (s: number, a: ActionsObject<number>) => any
116+
>view : (s: number, a: { foo: (s: number) => number; }) => any
117+
>(s, a) => undefined as any : (s: number, a: { foo: (s: number) => number; }) => any
118118
>s : number
119-
>a : ActionsObject<number>
119+
>a : { foo: (s: number) => number; }
120120
>undefined as any : any
121121
>undefined : undefined
122122

@@ -134,7 +134,7 @@ declare function app3<State, Actions extends ActionsArray<State>>(obj: Options<S
134134
app3({
135135
>app3({ state: 100, actions: [ s => s // Should be typed number => number ], view: (s, a) => undefined as any,}) : void
136136
>app3 : <State, Actions extends ActionsArray<State>>(obj: Options<State, Actions>) => void
137-
>{ state: 100, actions: [ s => s // Should be typed number => number ], view: (s, a) => undefined as any,} : { state: number; actions: ((s: number) => number)[]; view: (s: number, a: ActionsArray<number>) => any; }
137+
>{ state: 100, actions: [ s => s // Should be typed number => number ], view: (s, a) => undefined as any,} : { state: number; actions: ((s: number) => number)[]; view: (s: number, a: ((s: number) => number)[]) => any; }
138138

139139
state: 100,
140140
>state : number
@@ -151,10 +151,10 @@ app3({
151151

152152
],
153153
view: (s, a) => undefined as any,
154-
>view : (s: number, a: ActionsArray<number>) => any
155-
>(s, a) => undefined as any : (s: number, a: ActionsArray<number>) => any
154+
>view : (s: number, a: ((s: number) => number)[]) => any
155+
>(s, a) => undefined as any : (s: number, a: ((s: number) => number)[]) => any
156156
>s : number
157-
>a : ActionsArray<number>
157+
>a : ((s: number) => number)[]
158158
>undefined as any : any
159159
>undefined : undefined
160160

0 commit comments

Comments
 (0)