Skip to content

Commit 0406604

Browse files
committed
Rudimentary support for passing type parameters into generic functions.
Instead of generating an independent function instance for every combination of type parameters at compile time we construct generic function instances at runtime using "generic factory functions". Such a factory takes type params as arguments and returns a concrete instance of the function for the given type params (type param values are captured by the returned function as a closure and can be used as necessary). Here is an abbreviated example of how a generic function is compiled and called: ``` // Go: func F[T any](t T) {} f(1) // JS: F = function(T){ return function(t) {}; }; F($Int)(1); ``` This approach minimizes the size of the generated JS source, which is critical for the client-side use case, at the cost of runtime performance. See gopherjs#1013 (comment) for the detailed description. Note that the implementation in this commit is far from complete: - Generic function instances are not cached. - Generic types are not supported. - Declaring types dependent on type parameters doesn't work correctly. - Operators (such as `+`) do not work correctly with generic arguments.
1 parent 900dda7 commit 0406604

File tree

4 files changed

+123
-20
lines changed

4 files changed

+123
-20
lines changed

compiler/expressions.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,18 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
492492
)
493493
case *types.Basic:
494494
return fc.formatExpr("%e.charCodeAt(%f)", e.X, e.Index)
495+
case *types.Signature:
496+
return fc.translateGenericInstance(e)
495497
default:
496-
panic(fmt.Sprintf("Unhandled IndexExpr: %T\n", t))
498+
panic(fmt.Errorf("unhandled IndexExpr: %T", t))
499+
}
500+
case *ast.IndexListExpr:
501+
switch t := fc.pkgCtx.TypeOf(e.X).Underlying().(type) {
502+
case *types.Signature:
503+
return fc.translateGenericInstance(e)
504+
default:
505+
panic(fmt.Errorf("unhandled IndexListExpr: %T", t))
497506
}
498-
499507
case *ast.SliceExpr:
500508
if b, isBasic := fc.pkgCtx.TypeOf(e.X).Underlying().(*types.Basic); isBasic && isString(b) {
501509
switch {
@@ -749,6 +757,10 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
749757
case *types.Var, *types.Const:
750758
return fc.formatExpr("%s", fc.objectName(o))
751759
case *types.Func:
760+
if _, ok := fc.pkgCtx.Info.Instances[e]; ok {
761+
// Generic function call with auto-inferred types.
762+
return fc.translateGenericInstance(e)
763+
}
752764
return fc.formatExpr("%s", fc.objectName(o))
753765
case *types.TypeName:
754766
return fc.formatExpr("%s", fc.typeName(o.Type()))
@@ -788,6 +800,38 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
788800
}
789801
}
790802

803+
// translateGenericInstance translates a generic function instantiation.
804+
//
805+
// The returned JS expression evaluates into a callable function with type params
806+
// substituted.
807+
func (fc *funcContext) translateGenericInstance(e ast.Expr) *expression {
808+
var identifier *ast.Ident
809+
switch e := e.(type) {
810+
case *ast.Ident:
811+
identifier = e
812+
case *ast.IndexExpr:
813+
identifier = e.X.(*ast.Ident)
814+
case *ast.IndexListExpr:
815+
identifier = e.X.(*ast.Ident)
816+
default:
817+
err := bailout(fmt.Errorf("unexpected generic instantiation expression type %T at %s", e, fc.pkgCtx.fileSet.Position(e.Pos())))
818+
panic(err)
819+
}
820+
821+
instance, ok := fc.pkgCtx.Info.Instances[identifier]
822+
if !ok {
823+
err := fmt.Errorf("no matching generic instantiation for %q at %s", identifier, fc.pkgCtx.fileSet.Position(identifier.Pos()))
824+
bailout(err)
825+
}
826+
typeParams := []string{}
827+
for i := 0; i < instance.TypeArgs.Len(); i++ {
828+
t := instance.TypeArgs.At(i)
829+
typeParams = append(typeParams, fc.typeName(t))
830+
}
831+
o := fc.pkgCtx.Uses[identifier]
832+
return fc.formatExpr("%s(%s)", fc.objectName(o), strings.Join(typeParams, ", "))
833+
}
834+
791835
func (fc *funcContext) translateCall(e *ast.CallExpr, sig *types.Signature, fun *expression) *expression {
792836
args := fc.translateArgs(sig, e.Args, e.Ellipsis.IsValid())
793837
if fc.Blocking[e] {

compiler/package.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
180180
Implicits: make(map[ast.Node]types.Object),
181181
Selections: make(map[*ast.SelectorExpr]*types.Selection),
182182
Scopes: make(map[ast.Node]*types.Scope),
183+
Instances: make(map[*ast.Ident]types.Instance),
183184
}
184185

185186
var errList ErrorList
@@ -294,7 +295,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
294295
// but now we do it here to maintain previous behavior.
295296
continue
296297
}
297-
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariable(importedPkg.Name(), true)
298+
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariable(importedPkg.Name(), varPackage)
298299
importedPaths = append(importedPaths, importedPkg.Path())
299300
}
300301
sort.Strings(importedPaths)
@@ -521,7 +522,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
521522
d.DeclCode = funcCtx.CatchOutput(0, func() {
522523
typeName := funcCtx.objectName(o)
523524
lhs := typeName
524-
if isPkgLevel(o) {
525+
if typeVarLevel(o) == varPackage {
525526
lhs += " = $pkg." + encodeIdent(o.Name())
526527
}
527528
size := int64(0)
@@ -898,5 +899,22 @@ func translateFunction(typ *ast.FuncType, recv *ast.Ident, body *ast.BlockStmt,
898899

899900
c.pkgCtx.escapingVars = prevEV
900901

901-
return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
902+
if !c.sigTypes.IsGeneric() {
903+
return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
904+
}
905+
906+
// Generic functions are generated as factories to allow passing type parameters
907+
// from the call site.
908+
// TODO(nevkontakte): Cache function instances for a given combination of type
909+
// parameters.
910+
// TODO(nevkontakte): Generate type parameter arguments and derive all dependent
911+
// types inside the function.
912+
typeParams := []string{}
913+
for i := 0; i < c.sigTypes.Sig.TypeParams().Len(); i++ {
914+
typeParam := c.sigTypes.Sig.TypeParams().At(i)
915+
typeParams = append(typeParams, c.typeName(typeParam))
916+
}
917+
918+
return params, fmt.Sprintf("function%s(%s){ return function(%s) {\n%s%s}; }",
919+
functionName, strings.Join(typeParams, ", "), strings.Join(params, ", "), bodyOutput, c.Indentation(0))
902920
}

compiler/statements.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ func (fc *funcContext) translateStmt(stmt ast.Stmt, label *types.Label) {
444444
for _, spec := range decl.Specs {
445445
o := fc.pkgCtx.Defs[spec.(*ast.TypeSpec).Name].(*types.TypeName)
446446
fc.pkgCtx.typeNames = append(fc.pkgCtx.typeNames, o)
447-
fc.pkgCtx.objectNames[o] = fc.newVariable(o.Name(), true)
447+
fc.pkgCtx.objectNames[o] = fc.newVariable(o.Name(), varPackage)
448448
fc.pkgCtx.dependencies[o] = true
449449
}
450450
case token.CONST:

compiler/utils.go

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,23 @@ func (fc *funcContext) newConst(t types.Type, value constant.Value) ast.Expr {
237237
// local variable name. In this context "local" means "in scope of the current"
238238
// functionContext.
239239
func (fc *funcContext) newLocalVariable(name string) string {
240-
return fc.newVariable(name, false)
240+
return fc.newVariable(name, varFuncLocal)
241241
}
242242

243+
// varLevel specifies at which level a JavaScript variable should be declared.
244+
type varLevel int
245+
246+
const (
247+
// A variable defined at a function level (e.g. local variables).
248+
varFuncLocal = iota
249+
// A variable that should be declared in a generic type or function factory.
250+
// This is mainly for type parameters and generic-dependent types.
251+
varGenericFactory
252+
// A variable that should be declared in a package factory. This user is for
253+
// top-level functions, types, etc.
254+
varPackage
255+
)
256+
243257
// newVariable assigns a new JavaScript variable name for the given Go variable
244258
// or type.
245259
//
@@ -252,7 +266,7 @@ func (fc *funcContext) newLocalVariable(name string) string {
252266
// to this functionContext, as well as all parents, but not to the list of local
253267
// variables. If false, it is added to this context only, as well as the list of
254268
// local vars.
255-
func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
269+
func (fc *funcContext) newVariable(name string, level varLevel) string {
256270
if name == "" {
257271
panic("newVariable: empty name")
258272
}
@@ -261,7 +275,7 @@ func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
261275
i := 0
262276
for {
263277
offset := int('a')
264-
if pkgLevel {
278+
if level == varPackage {
265279
offset = int('A')
266280
}
267281
j := i
@@ -286,9 +300,22 @@ func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
286300
varName = fmt.Sprintf("%s$%d", name, n)
287301
}
288302

289-
if pkgLevel {
290-
for c2 := fc.parent; c2 != nil; c2 = c2.parent {
291-
c2.allVars[name] = n + 1
303+
// Package-level variables are registered in all outer scopes.
304+
if level == varPackage {
305+
for c := fc.parent; c != nil; c = c.parent {
306+
c.allVars[name] = n + 1
307+
}
308+
return varName
309+
}
310+
311+
// Generic-factory level variables are registered in outer scopes up to the
312+
// level of the generic function or method.
313+
if level == varGenericFactory {
314+
for c := fc; c != nil; c = c.parent {
315+
c.allVars[name] = n + 1
316+
if c.sigTypes.IsGeneric() {
317+
break
318+
}
292319
}
293320
return varName
294321
}
@@ -331,14 +358,20 @@ func isVarOrConst(o types.Object) bool {
331358
return false
332359
}
333360

334-
func isPkgLevel(o types.Object) bool {
335-
return o.Parent() != nil && o.Parent().Parent() == types.Universe
361+
func typeVarLevel(o types.Object) varLevel {
362+
if _, ok := o.Type().(*types.TypeParam); ok {
363+
return varGenericFactory
364+
}
365+
if o.Parent() != nil && o.Parent().Parent() == types.Universe {
366+
return varPackage
367+
}
368+
return varFuncLocal
336369
}
337370

338371
// objectName returns a JS identifier corresponding to the given types.Object.
339372
// Repeated calls for the same object will return the same name.
340373
func (fc *funcContext) objectName(o types.Object) string {
341-
if isPkgLevel(o) {
374+
if typeVarLevel(o) == varPackage {
342375
fc.pkgCtx.dependencies[o] = true
343376

344377
if o.Pkg() != fc.pkgCtx.Pkg || (isVarOrConst(o) && o.Exported()) {
@@ -348,7 +381,7 @@ func (fc *funcContext) objectName(o types.Object) string {
348381

349382
name, ok := fc.pkgCtx.objectNames[o]
350383
if !ok {
351-
name = fc.newVariable(o.Name(), isPkgLevel(o))
384+
name = fc.newVariable(o.Name(), typeVarLevel(o))
352385
fc.pkgCtx.objectNames[o] = name
353386
}
354387

@@ -359,13 +392,13 @@ func (fc *funcContext) objectName(o types.Object) string {
359392
}
360393

361394
func (fc *funcContext) varPtrName(o *types.Var) string {
362-
if isPkgLevel(o) && o.Exported() {
395+
if typeVarLevel(o) == varPackage && o.Exported() {
363396
return fc.pkgVar(o.Pkg()) + "." + o.Name() + "$ptr"
364397
}
365398

366399
name, ok := fc.pkgCtx.varPtrNames[o]
367400
if !ok {
368-
name = fc.newVariable(o.Name()+"$ptr", isPkgLevel(o))
401+
name = fc.newVariable(o.Name()+"$ptr", typeVarLevel(o))
369402
fc.pkgCtx.varPtrNames[o] = name
370403
}
371404
return name
@@ -385,6 +418,8 @@ func (fc *funcContext) typeName(ty types.Type) string {
385418
return "$error"
386419
}
387420
return fc.objectName(t.Obj())
421+
case *types.TypeParam:
422+
return fc.objectName(t.Obj())
388423
case *types.Interface:
389424
if t.Empty() {
390425
return "$emptyInterface"
@@ -397,8 +432,8 @@ func (fc *funcContext) typeName(ty types.Type) string {
397432
// repeatedly.
398433
anonType, ok := fc.pkgCtx.anonTypeMap.At(ty).(*types.TypeName)
399434
if !ok {
400-
fc.initArgs(ty) // cause all embedded types to be registered
401-
varName := fc.newVariable(strings.ToLower(typeKind(ty)[5:])+"Type", true)
435+
fc.initArgs(ty) // cause all dependency types to be registered
436+
varName := fc.newVariable(strings.ToLower(typeKind(ty)[5:])+"Type", varPackage)
402437
anonType = types.NewTypeName(token.NoPos, fc.pkgCtx.Pkg, varName, ty) // fake types.TypeName
403438
fc.pkgCtx.anonTypes = append(fc.pkgCtx.anonTypes, anonType)
404439
fc.pkgCtx.anonTypeMap.Set(ty, anonType)
@@ -815,6 +850,12 @@ func (st signatureTypes) HasNamedResults() bool {
815850
return st.HasResults() && st.Sig.Results().At(0).Name() != ""
816851
}
817852

853+
// IsGeneric returns true if the signature represents a generic function or a
854+
// method of a generic type.
855+
func (st signatureTypes) IsGeneric() bool {
856+
return st.Sig.TypeParams().Len() > 0 || st.Sig.RecvTypeParams().Len() > 0
857+
}
858+
818859
// ErrorAt annotates an error with a position in the source code.
819860
func ErrorAt(err error, fset *token.FileSet, pos token.Pos) error {
820861
return fmt.Errorf("%s: %w", fset.Position(pos), err)

0 commit comments

Comments
 (0)