Skip to content

Commit 65222d6

Browse files
committed
Check that generator return type is assignable from its IterableIterator type
1 parent a9e1d48 commit 65222d6

File tree

1 file changed

+64
-17
lines changed

1 file changed

+64
-17
lines changed

src/compiler/checker.ts

+64-17
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ module ts {
111111
let globalTemplateStringsArrayType: ObjectType;
112112
let globalESSymbolType: ObjectType;
113113
let globalIterableType: GenericType;
114+
let globalIteratorType: GenericType;
114115
let globalIterableIteratorType: GenericType;
115116

116117
let anyArrayType: Type;
@@ -2119,7 +2120,7 @@ module ts {
21192120
// checkRightHandSideOfForOf will return undefined if the for-of expression type was
21202121
// missing properties/signatures required to get its iteratedType (like
21212122
// [Symbol.iterator] or next). This may be because we accessed properties from anyType,
2122-
// or it may have led to an error inside getIteratedType.
2123+
// or it may have led to an error inside getElementTypeFromIterable.
21232124
return checkRightHandSideOfForOf((<ForOfStatement>declaration.parent.parent).expression) || anyType;
21242125
}
21252126
if (isBindingPattern(declaration.parent)) {
@@ -5854,7 +5855,7 @@ module ts {
58545855
let index = indexOf(arrayLiteral.elements, node);
58555856
return getTypeOfPropertyOfContextualType(type, "" + index)
58565857
|| getIndexTypeOfContextualType(type, IndexKind.Number)
5857-
|| (languageVersion >= ScriptTarget.ES6 ? getIteratedType(type, /*expressionForError*/ undefined) : undefined);
5858+
|| (languageVersion >= ScriptTarget.ES6 ? getElementTypeFromIterable(type, /*expressionForError*/ undefined) : undefined);
58585859
}
58595860
return undefined;
58605861
}
@@ -6041,7 +6042,7 @@ module ts {
60416042
// if there is no index type / iterated type.
60426043
let restArrayType = checkExpression((<SpreadElementExpression>e).expression, contextualMapper);
60436044
let restElementType = getIndexTypeOfType(restArrayType, IndexKind.Number) ||
6044-
(languageVersion >= ScriptTarget.ES6 ? getIteratedType(restArrayType, /*expressionForError*/ undefined) : undefined);
6045+
(languageVersion >= ScriptTarget.ES6 ? getElementTypeFromIterable(restArrayType, /*expressionForError*/ undefined) : undefined);
60456046

60466047
if (restElementType) {
60476048
elementTypes.push(restElementType);
@@ -8188,6 +8189,22 @@ module ts {
81888189
break;
81898190
}
81908191
}
8192+
8193+
if (node.type) {
8194+
if (languageVersion >= ScriptTarget.ES6 && isSyntacticallyValidGenerator(node)) {
8195+
let returnType = getTypeFromTypeNode(node.type);
8196+
let generatorElementType = getElementTypeFromIterableIterator(returnType, /*errorNode*/ undefined) || anyType;
8197+
let iterableIteratorInstantiation = createIterableIteratorType(generatorElementType);
8198+
8199+
// Naively, one could check that IterableIterator<any> is assignable to the return type annotation.
8200+
// However, that would not catch the error in the following case.
8201+
//
8202+
// interface BadGenerator extends Iterable<number>, Iterator<string> { }
8203+
// function* g(): BadGenerator { } // Iterable and Iterator have different types!
8204+
//
8205+
checkTypeAssignableTo(iterableIteratorInstantiation, returnType, node.type);
8206+
}
8207+
}
81918208
}
81928209

81938210
checkSpecializedSignatureDeclaration(node);
@@ -9385,7 +9402,7 @@ module ts {
93859402
// iteratedType will be undefined if the rightType was missing properties/signatures
93869403
// required to get its iteratedType (like [Symbol.iterator] or next). This may be
93879404
// because we accessed properties from anyType, or it may have led to an error inside
9388-
// getIteratedType.
9405+
// getElementTypeFromIterable.
93899406
if (iteratedType) {
93909407
checkTypeAssignableTo(iteratedType, leftType, varExpr, /*headMessage*/ undefined);
93919408
}
@@ -9483,30 +9500,24 @@ module ts {
94839500
* When errorNode is undefined, it means we should not report any errors.
94849501
*/
94859502
function checkIteratedType(iterable: Type, errorNode: Node): Type {
9486-
let iteratedType = getIteratedType(iterable, errorNode);
9503+
let elementType = getElementTypeFromIterable(iterable, errorNode);
94879504
// Now even though we have extracted the iteratedType, we will have to validate that the type
94889505
// passed in is actually an Iterable.
9489-
if (errorNode && iteratedType) {
9490-
checkTypeAssignableTo(iterable, createIterableType(iteratedType), errorNode);
9506+
if (errorNode && elementType) {
9507+
checkTypeAssignableTo(iterable, createIterableType(elementType), errorNode);
94919508
}
94929509

9493-
return iteratedType;
9510+
return elementType;
94949511
}
94959512

9496-
function getIteratedType(iterable: Type, errorNode: Node) {
9513+
function getElementTypeFromIterable(iterable: Type, errorNode: Node): Type {
94979514
Debug.assert(languageVersion >= ScriptTarget.ES6);
94989515
// We want to treat type as an iterable, and get the type it is an iterable of. The iterable
94999516
// must have the following structure (annotated with the names of the variables below):
95009517
//
95019518
// { // iterable
95029519
// [Symbol.iterator]: { // iteratorFunction
9503-
// (): { // iterator
9504-
// next: { // iteratorNextFunction
9505-
// (): { // iteratorNextResult
9506-
// value: T // iteratorNextValue
9507-
// }
9508-
// }
9509-
// }
9520+
// (): Iterator<T>
95109521
// }
95119522
// }
95129523
//
@@ -9544,11 +9555,31 @@ module ts {
95449555
return undefined;
95459556
}
95469557

9547-
let iterator = getUnionType(map(iteratorFunctionSignatures, getReturnTypeOfSignature));
9558+
return getElementTypeFromIterator(getUnionType(map(iteratorFunctionSignatures, getReturnTypeOfSignature)), errorNode);
9559+
}
9560+
9561+
function getElementTypeFromIterator(iterator: Type, errorNode: Node): Type {
9562+
// This function has very similar logic as getElementTypeFromIterable, except that it operates on
9563+
// Iterators instead of Iterables. Here is the structure:
9564+
//
9565+
// { // iterator
9566+
// next: { // iteratorNextFunction
9567+
// (): { // iteratorNextResult
9568+
// value: T // iteratorNextValue
9569+
// }
9570+
// }
9571+
// }
9572+
//
95489573
if (allConstituentTypesHaveKind(iterator, TypeFlags.Any)) {
95499574
return undefined;
95509575
}
95519576

9577+
// As an optimization, if the type is instantiated directly using the globalIteratorType (Iterator<number>),
9578+
// then just grab its type argument.
9579+
if ((iterator.flags & TypeFlags.Reference) && (<GenericType>iterator).target === globalIteratorType) {
9580+
return (<GenericType>iterator).typeArguments[0];
9581+
}
9582+
95529583
let iteratorNextFunction = getTypeOfPropertyOfType(iterator, "next");
95539584
if (iteratorNextFunction && allConstituentTypesHaveKind(iteratorNextFunction, TypeFlags.Any)) {
95549585
return undefined;
@@ -9578,6 +9609,21 @@ module ts {
95789609
return iteratorNextValue;
95799610
}
95809611

9612+
function getElementTypeFromIterableIterator(iterableIterator: Type, errorNode: Node): Type {
9613+
if (allConstituentTypesHaveKind(iterableIterator, TypeFlags.Any)) {
9614+
return undefined;
9615+
}
9616+
9617+
// As an optimization, if the type is instantiated directly using the globalIterableIteratorType (IterableIterator<number>),
9618+
// then just grab its type argument.
9619+
if ((iterableIterator.flags & TypeFlags.Reference) && (<GenericType>iterableIterator).target === globalIterableIteratorType) {
9620+
return (<GenericType>iterableIterator).typeArguments[0];
9621+
}
9622+
9623+
return getElementTypeFromIterable(iterableIterator, errorNode) ||
9624+
getElementTypeFromIterator(iterableIterator, errorNode);
9625+
}
9626+
95819627
/**
95829628
* This function does the following steps:
95839629
* 1. Break up arrayOrStringType (possibly a union) into its string constituents and array constituents.
@@ -12000,6 +12046,7 @@ module ts {
1200012046
globalESSymbolType = getGlobalType("Symbol");
1200112047
globalESSymbolConstructorSymbol = getGlobalValueSymbol("Symbol");
1200212048
globalIterableType = <GenericType>getGlobalType("Iterable", /*arity*/ 1);
12049+
globalIteratorType = <GenericType>getGlobalType("Iterator", /*arity*/ 1);
1200312050
globalIterableIteratorType = <GenericType>getGlobalType("IterableIterator", /*arity*/ 1);
1200412051
}
1200512052
else {

0 commit comments

Comments
 (0)