Skip to content

Commit cf5cb00

Browse files
adonovangopherbot
authored andcommitted
internal/astutil: PreorderStack: a safer ast.Inspect for stacks
This CL defines PreorderStack, a safer function than ast.Inspect for when you need to maintain a stack. Beware, the stack that it produces does not include n itself--a half-open interval--so that nested traversals compose correctly. The CL also uses the new function in various places in x/tools where appropriate; in some cases it was clearer to rewrite using cursor.Cursor. + test Updates golang/go#73319 Change-Id: I843122cdd49cc4af8a7318badd8c34389479a92a Reviewed-on: https://go-review.googlesource.com/c/tools/+/664635 Auto-Submit: Alan Donovan <[email protected]> Commit-Queue: Alan Donovan <[email protected]> Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent f76b112 commit cf5cb00

File tree

8 files changed

+171
-115
lines changed

8 files changed

+171
-115
lines changed

go/analysis/passes/lostcancel/lostcancel.go

+7-14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"golang.org/x/tools/go/ast/inspector"
1818
"golang.org/x/tools/go/cfg"
1919
"golang.org/x/tools/internal/analysisinternal"
20+
"golang.org/x/tools/internal/astutil"
2021
)
2122

2223
//go:embed doc.go
@@ -83,30 +84,22 @@ func runFunc(pass *analysis.Pass, node ast.Node) {
8384
// {FuncDecl,FuncLit,CallExpr,SelectorExpr}.
8485

8586
// Find the set of cancel vars to analyze.
86-
stack := make([]ast.Node, 0, 32)
87-
ast.Inspect(node, func(n ast.Node) bool {
88-
switch n.(type) {
89-
case *ast.FuncLit:
90-
if len(stack) > 0 {
91-
return false // don't stray into nested functions
92-
}
93-
case nil:
94-
stack = stack[:len(stack)-1] // pop
95-
return true
87+
astutil.PreorderStack(node, nil, func(n ast.Node, stack []ast.Node) bool {
88+
if _, ok := n.(*ast.FuncLit); ok && len(stack) > 0 {
89+
return false // don't stray into nested functions
9690
}
97-
stack = append(stack, n) // push
9891

99-
// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
92+
// Look for n=SelectorExpr beneath stack=[{AssignStmt,ValueSpec} CallExpr]:
10093
//
10194
// ctx, cancel := context.WithCancel(...)
10295
// ctx, cancel = context.WithCancel(...)
10396
// var ctx, cancel = context.WithCancel(...)
10497
//
105-
if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
98+
if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-1]) {
10699
return true
107100
}
108101
var id *ast.Ident // id of cancel var
109-
stmt := stack[len(stack)-3]
102+
stmt := stack[len(stack)-2]
110103
switch stmt := stmt.(type) {
111104
case *ast.ValueSpec:
112105
if len(stmt.Names) > 1 {

gopls/internal/golang/codeaction.go

+10-21
Original file line numberDiff line numberDiff line change
@@ -713,33 +713,24 @@ func refactorRewriteEliminateDotImport(ctx context.Context, req *codeActionsRequ
713713

714714
// Go through each use of the dot imported package, checking its scope for
715715
// shadowing and calculating an edit to qualify the identifier.
716-
var stack []ast.Node
717-
ast.Inspect(req.pgf.File, func(n ast.Node) bool {
718-
if n == nil {
719-
stack = stack[:len(stack)-1] // pop
720-
return false
721-
}
722-
stack = append(stack, n) // push
716+
for curId := range req.pgf.Cursor.Preorder((*ast.Ident)(nil)) {
717+
ident := curId.Node().(*ast.Ident)
723718

724-
ident, ok := n.(*ast.Ident)
725-
if !ok {
726-
return true
727-
}
728719
// Only keep identifiers that use a symbol from the
729720
// dot imported package.
730721
use := req.pkg.TypesInfo().Uses[ident]
731722
if use == nil || use.Pkg() == nil {
732-
return true
723+
continue
733724
}
734725
if use.Pkg() != imported {
735-
return true
726+
continue
736727
}
737728

738729
// Only qualify unqualified identifiers (due to dot imports).
739730
// All other references to a symbol imported from another package
740731
// are nested within a select expression (pkg.Foo, v.Method, v.Field).
741-
if is[*ast.SelectorExpr](stack[len(stack)-2]) {
742-
return true
732+
if is[*ast.SelectorExpr](curId.Parent().Node()) {
733+
continue
743734
}
744735

745736
// Make sure that the package name will not be shadowed by something else in scope.
@@ -750,24 +741,22 @@ func refactorRewriteEliminateDotImport(ctx context.Context, req *codeActionsRequ
750741
// allowed to go through.
751742
sc := fileScope.Innermost(ident.Pos())
752743
if sc == nil {
753-
return true
744+
continue
754745
}
755746
_, obj := sc.LookupParent(newName, ident.Pos())
756747
if obj != nil {
757-
return true
748+
continue
758749
}
759750

760751
rng, err := req.pgf.PosRange(ident.Pos(), ident.Pos()) // sic, zero-width range before ident
761752
if err != nil {
762-
return true
753+
continue
763754
}
764755
edits = append(edits, protocol.TextEdit{
765756
Range: rng,
766757
NewText: newName + ".",
767758
})
768-
769-
return true
770-
})
759+
}
771760

772761
req.addEditAction("Eliminate dot import", nil, protocol.DocumentChangeEdit(
773762
req.fh,

gopls/internal/golang/hover.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
gastutil "golang.org/x/tools/gopls/internal/util/astutil"
3939
"golang.org/x/tools/gopls/internal/util/bug"
4040
"golang.org/x/tools/gopls/internal/util/safetoken"
41+
internalastutil "golang.org/x/tools/internal/astutil"
4142
"golang.org/x/tools/internal/event"
4243
"golang.org/x/tools/internal/stdlib"
4344
"golang.org/x/tools/internal/tokeninternal"
@@ -1502,16 +1503,10 @@ func findDeclInfo(files []*ast.File, pos token.Pos) (decl ast.Decl, spec ast.Spe
15021503
stack := make([]ast.Node, 0, 20)
15031504

15041505
// Allocate the closure once, outside the loop.
1505-
f := func(n ast.Node) bool {
1506+
f := func(n ast.Node, stack []ast.Node) bool {
15061507
if found {
15071508
return false
15081509
}
1509-
if n != nil {
1510-
stack = append(stack, n) // push
1511-
} else {
1512-
stack = stack[:len(stack)-1] // pop
1513-
return false
1514-
}
15151510

15161511
// Skip subtrees (incl. files) that don't contain the search point.
15171512
if !(n.Pos() <= pos && pos < n.End()) {
@@ -1596,7 +1591,7 @@ func findDeclInfo(files []*ast.File, pos token.Pos) (decl ast.Decl, spec ast.Spe
15961591
return true
15971592
}
15981593
for _, file := range files {
1599-
ast.Inspect(file, f)
1594+
internalastutil.PreorderStack(file, stack, f)
16001595
if found {
16011596
return decl, spec, field
16021597
}

gopls/internal/golang/rename_check.go

+31-33
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ import (
4545
"golang.org/x/tools/go/ast/astutil"
4646
"golang.org/x/tools/gopls/internal/cache"
4747
"golang.org/x/tools/gopls/internal/util/safetoken"
48+
"golang.org/x/tools/internal/astutil/cursor"
49+
"golang.org/x/tools/internal/astutil/edge"
4850
"golang.org/x/tools/internal/typeparams"
4951
"golang.org/x/tools/internal/typesinternal"
5052
"golang.org/x/tools/refactor/satisfy"
@@ -338,64 +340,61 @@ func deeper(x, y *types.Scope) bool {
338340
// lexical block enclosing the reference. If fn returns false the
339341
// iteration is terminated and findLexicalRefs returns false.
340342
func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ident, block *types.Scope) bool) bool {
343+
filter := []ast.Node{
344+
(*ast.Ident)(nil),
345+
(*ast.SelectorExpr)(nil),
346+
(*ast.CompositeLit)(nil),
347+
}
341348
ok := true
342-
var stack []ast.Node
343-
344-
var visit func(n ast.Node) bool
345-
visit = func(n ast.Node) bool {
346-
if n == nil {
347-
stack = stack[:len(stack)-1] // pop
348-
return false
349-
}
349+
var visit func(cur cursor.Cursor, push bool) (descend bool)
350+
visit = func(cur cursor.Cursor, push bool) (descend bool) {
350351
if !ok {
351352
return false // bail out
352353
}
353-
354-
stack = append(stack, n) // push
355-
switch n := n.(type) {
354+
if !push {
355+
return false
356+
}
357+
switch n := cur.Node().(type) {
356358
case *ast.Ident:
357359
if pkg.TypesInfo().Uses[n] == obj {
358-
block := enclosingBlock(pkg.TypesInfo(), stack)
360+
block := enclosingBlock(pkg.TypesInfo(), cur)
359361
if !fn(n, block) {
360362
ok = false
361363
}
362364
}
363-
return visit(nil) // pop stack
364365

365366
case *ast.SelectorExpr:
366367
// don't visit n.Sel
367-
ast.Inspect(n.X, visit)
368-
return visit(nil) // pop stack, don't descend
368+
cur.ChildAt(edge.SelectorExpr_X, -1).Inspect(filter, visit)
369+
return false // don't descend
369370

370371
case *ast.CompositeLit:
371372
// Handle recursion ourselves for struct literals
372373
// so we don't visit field identifiers.
373374
tv, ok := pkg.TypesInfo().Types[n]
374375
if !ok {
375-
return visit(nil) // pop stack, don't descend
376+
return false // don't descend
376377
}
377378
if is[*types.Struct](typeparams.CoreType(typeparams.Deref(tv.Type))) {
378379
if n.Type != nil {
379-
ast.Inspect(n.Type, visit)
380+
cur.ChildAt(edge.CompositeLit_Type, -1).Inspect(filter, visit)
380381
}
381-
for _, elt := range n.Elts {
382-
if kv, ok := elt.(*ast.KeyValueExpr); ok {
383-
ast.Inspect(kv.Value, visit)
384-
} else {
385-
ast.Inspect(elt, visit)
382+
for i, elt := range n.Elts {
383+
curElt := cur.ChildAt(edge.CompositeLit_Elts, i)
384+
if _, ok := elt.(*ast.KeyValueExpr); ok {
385+
// skip kv.Key
386+
curElt = curElt.ChildAt(edge.KeyValueExpr_Value, -1)
386387
}
388+
curElt.Inspect(filter, visit)
387389
}
388-
return visit(nil) // pop stack, don't descend
390+
return false // don't descend
389391
}
390392
}
391393
return true
392394
}
393395

394-
for _, f := range pkg.Syntax() {
395-
ast.Inspect(f, visit)
396-
if len(stack) != 0 {
397-
panic(stack)
398-
}
396+
for _, pgf := range pkg.CompiledGoFiles() {
397+
pgf.Cursor.Inspect(filter, visit)
399398
if !ok {
400399
break
401400
}
@@ -404,11 +403,10 @@ func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ide
404403
}
405404

406405
// enclosingBlock returns the innermost block logically enclosing the
407-
// specified AST node (an ast.Ident), specified in the form of a path
408-
// from the root of the file, [file...n].
409-
func enclosingBlock(info *types.Info, stack []ast.Node) *types.Scope {
410-
for i := range stack {
411-
n := stack[len(stack)-1-i]
406+
// AST node (an ast.Ident), specified as a Cursor.
407+
func enclosingBlock(info *types.Info, curId cursor.Cursor) *types.Scope {
408+
for cur := range curId.Enclosing() {
409+
n := cur.Node()
412410
// For some reason, go/types always associates a
413411
// function's scope with its FuncType.
414412
// See comments about scope above.

internal/astutil/util.go

+32
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,35 @@ func PosInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) {
5757
}
5858
return pos, nil
5959
}
60+
61+
// PreorderStack traverses the tree rooted at root,
62+
// calling f before visiting each node.
63+
//
64+
// Each call to f provides the current node and traversal stack,
65+
// consisting of the original value of stack appended with all nodes
66+
// from root to n, excluding n itself. (This design allows calls
67+
// to PreorderStack to be nested without double counting.)
68+
//
69+
// If f returns false, the traversal skips over that subtree. Unlike
70+
// [ast.Inspect], no second call to f is made after visiting node n.
71+
// In practice, the second call is nearly always used only to pop the
72+
// stack, and it is surprisingly tricky to do this correctly; see
73+
// https://go.dev/issue/73319.
74+
func PreorderStack(root ast.Node, stack []ast.Node, f func(n ast.Node, stack []ast.Node) bool) {
75+
before := len(stack)
76+
ast.Inspect(root, func(n ast.Node) bool {
77+
if n != nil {
78+
if !f(n, stack) {
79+
// Do not push, as there will be no corresponding pop.
80+
return false
81+
}
82+
stack = append(stack, n) // push
83+
} else {
84+
stack = stack[:len(stack)-1] // pop
85+
}
86+
return true
87+
})
88+
if len(stack) != before {
89+
panic("push/pop mismatch")
90+
}
91+
}

internal/astutil/util_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package astutil_test
6+
7+
import (
8+
"fmt"
9+
"go/ast"
10+
"go/parser"
11+
"go/token"
12+
"reflect"
13+
"strings"
14+
"testing"
15+
16+
"golang.org/x/tools/internal/astutil"
17+
)
18+
19+
func TestPreorderStack(t *testing.T) {
20+
const src = `package a
21+
func f() {
22+
print("hello")
23+
}
24+
func g() {
25+
print("goodbye")
26+
panic("oops")
27+
}
28+
`
29+
fset := token.NewFileSet()
30+
f, _ := parser.ParseFile(fset, "a.go", src, 0)
31+
32+
str := func(n ast.Node) string {
33+
return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.")
34+
}
35+
36+
var events []string
37+
var gotStack []string
38+
astutil.PreorderStack(f, nil, func(n ast.Node, stack []ast.Node) bool {
39+
events = append(events, str(n))
40+
if decl, ok := n.(*ast.FuncDecl); ok && decl.Name.Name == "f" {
41+
return false // skip subtree of f()
42+
}
43+
if lit, ok := n.(*ast.BasicLit); ok && lit.Value == `"oops"` {
44+
for _, n := range stack {
45+
gotStack = append(gotStack, str(n))
46+
}
47+
}
48+
return true
49+
})
50+
51+
// Check sequence of events.
52+
const wantEvents = `[File Ident ` + // package a
53+
`FuncDecl ` + // func f() [pruned]
54+
`FuncDecl Ident FuncType FieldList BlockStmt ` + // func g()
55+
`ExprStmt CallExpr Ident BasicLit ` + // print...
56+
`ExprStmt CallExpr Ident BasicLit]` // panic...
57+
if got := fmt.Sprint(events); got != wantEvents {
58+
t.Errorf("PreorderStack events:\ngot: %s\nwant: %s", got, wantEvents)
59+
}
60+
61+
// Check captured stack.
62+
const wantStack = `[File FuncDecl BlockStmt ExprStmt CallExpr]`
63+
if got := fmt.Sprint(gotStack); got != wantStack {
64+
t.Errorf("PreorderStack stack:\ngot: %s\nwant: %s", got, wantStack)
65+
}
66+
67+
}

0 commit comments

Comments
 (0)