Skip to content

Commit ef9d47d

Browse files
committed
feat: better discriminate assignations to struct pointers
1 parent 939d65b commit ef9d47d

File tree

7 files changed

+173
-51
lines changed

7 files changed

+173
-51
lines changed

README.md

-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ go install github.com/Crocmagnon/fatcontext/cmd/fatcontext@latest
1616
fatcontext ./...
1717
```
1818

19-
There are no specific configuration options or custom command-line flags.
20-
2119
## Example
2220

2321
```go

cmd/fatcontext/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ import (
77
)
88

99
func main() {
10-
singlechecker.Main(analyzer.Analyzer)
10+
singlechecker.Main(analyzer.NewAnalyzer())
1111
}

pkg/analyzer/analyzer.go

+106-44
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package analyzer
33
import (
44
"bytes"
55
"errors"
6+
"flag"
67
"fmt"
78
"go/ast"
89
"go/printer"
@@ -14,22 +15,45 @@ import (
1415
"golang.org/x/tools/go/ast/inspector"
1516
)
1617

17-
var Analyzer = &analysis.Analyzer{
18-
Name: "fatcontext",
19-
Doc: "detects nested contexts in loops and function literals",
20-
Run: run,
21-
Requires: []*analysis.Analyzer{inspect.Analyzer},
18+
const FlagCheckStructPointers = "check-struct-pointers"
19+
20+
func NewAnalyzer() *analysis.Analyzer {
21+
r := &runner{}
22+
23+
flags := flag.NewFlagSet("fatcontext", flag.ExitOnError)
24+
flags.BoolVar(&r.DetectInStructPointers, FlagCheckStructPointers, false,
25+
"set to true to detect potential fat contexts in struct pointers")
26+
27+
return &analysis.Analyzer{
28+
Name: "fatcontext",
29+
Doc: "detects nested contexts in loops and function literals",
30+
Run: r.run,
31+
Flags: *flags,
32+
Requires: []*analysis.Analyzer{inspect.Analyzer},
33+
}
2234
}
2335

2436
var errUnknown = errors.New("unknown node type")
2537

26-
func run(pass *analysis.Pass) (interface{}, error) {
38+
const (
39+
categoryInLoop = "nested context in loop"
40+
categoryInFuncLit = "nested context in function literal"
41+
categoryInStructPointer = "potential nested context in struct pointer"
42+
categoryUnsupported = "unsupported nested context type"
43+
)
44+
45+
type runner struct {
46+
DetectInStructPointers bool
47+
}
48+
49+
func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
2750
inspctr := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
2851

2952
nodeFilter := []ast.Node{
3053
(*ast.ForStmt)(nil),
3154
(*ast.RangeStmt)(nil),
3255
(*ast.FuncLit)(nil),
56+
(*ast.FuncDecl)(nil),
3357
}
3458

3559
inspctr.Preorder(nodeFilter, func(node ast.Node) {
@@ -43,63 +67,87 @@ func run(pass *analysis.Pass) (interface{}, error) {
4367
return
4468
}
4569

46-
suggestedStmt := ast.AssignStmt{
47-
Lhs: assignStmt.Lhs,
48-
TokPos: assignStmt.TokPos,
49-
Tok: token.DEFINE,
50-
Rhs: assignStmt.Rhs,
51-
}
52-
suggested, err := render(pass.Fset, &suggestedStmt)
53-
54-
var fixes []analysis.SuggestedFix
55-
if err == nil {
56-
fixes = append(fixes, analysis.SuggestedFix{
57-
Message: "replace `=` with `:=`",
58-
TextEdits: []analysis.TextEdit{
59-
{
60-
Pos: assignStmt.Pos(),
61-
End: assignStmt.End(),
62-
NewText: suggested,
63-
},
64-
},
65-
})
70+
category := getCategory(pass, node, assignStmt)
71+
72+
if r.shouldIgnoreReport(category) {
73+
return
6674
}
6775

76+
fixes := r.getSuggestedFixes(pass, assignStmt, category)
77+
6878
pass.Report(analysis.Diagnostic{
6979
Pos: assignStmt.Pos(),
70-
Message: getReportMessage(node),
80+
Message: category,
7181
SuggestedFixes: fixes,
7282
})
7383
})
7484

7585
return nil, nil
7686
}
7787

78-
func getReportMessage(node ast.Node) string {
88+
func (r *runner) shouldIgnoreReport(category string) bool {
89+
return category == categoryInStructPointer && !r.DetectInStructPointers
90+
}
91+
92+
func (r *runner) getSuggestedFixes(pass *analysis.Pass, assignStmt *ast.AssignStmt, category string) []analysis.SuggestedFix {
93+
switch category {
94+
case categoryInStructPointer, categoryUnsupported:
95+
return nil
96+
}
97+
98+
suggestedStmt := ast.AssignStmt{
99+
Lhs: assignStmt.Lhs,
100+
TokPos: assignStmt.TokPos,
101+
Tok: token.DEFINE,
102+
Rhs: assignStmt.Rhs,
103+
}
104+
suggested, err := render(pass.Fset, &suggestedStmt)
105+
106+
var fixes []analysis.SuggestedFix
107+
if err == nil {
108+
fixes = append(fixes, analysis.SuggestedFix{
109+
Message: "replace `=` with `:=`",
110+
TextEdits: []analysis.TextEdit{
111+
{
112+
Pos: assignStmt.Pos(),
113+
End: assignStmt.End(),
114+
NewText: suggested,
115+
},
116+
},
117+
})
118+
}
119+
120+
return fixes
121+
}
122+
123+
func getCategory(pass *analysis.Pass, node ast.Node, assignStmt *ast.AssignStmt) string {
79124
switch node.(type) {
80125
case *ast.ForStmt, *ast.RangeStmt:
81-
return "nested context in loop"
82-
case *ast.FuncLit:
83-
return "nested context in function literal"
84-
default:
85-
return "unsupported nested context type"
126+
return categoryInLoop
86127
}
87-
}
88128

89-
func getBody(node ast.Node) (*ast.BlockStmt, error) {
90-
forStmt, ok := node.(*ast.ForStmt)
91-
if ok {
92-
return forStmt.Body, nil
129+
if isPointer(pass, assignStmt.Lhs[0]) {
130+
return categoryInStructPointer
93131
}
94132

95-
rangeStmt, ok := node.(*ast.RangeStmt)
96-
if ok {
97-
return rangeStmt.Body, nil
133+
switch node.(type) {
134+
case *ast.FuncLit, *ast.FuncDecl:
135+
return categoryInFuncLit
136+
default:
137+
return categoryUnsupported
98138
}
139+
}
99140

100-
funcLit, ok := node.(*ast.FuncLit)
101-
if ok {
102-
return funcLit.Body, nil
141+
func getBody(node ast.Node) (*ast.BlockStmt, error) {
142+
switch typedNode := node.(type) {
143+
case *ast.ForStmt:
144+
return typedNode.Body, nil
145+
case *ast.RangeStmt:
146+
return typedNode.Body, nil
147+
case *ast.FuncLit:
148+
return typedNode.Body, nil
149+
case *ast.FuncDecl:
150+
return typedNode.Body, nil
103151
}
104152

105153
return nil, errUnknown
@@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
174222
continue
175223
}
176224

225+
if isPointer(pass, assignStmt.Lhs[0]) {
226+
return assignStmt
227+
}
228+
177229
// allow assignment to non-pointer children of values defined within the loop
178230
if isWithinLoop(assignStmt.Lhs[0], node, pass) {
179231
continue
@@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
249301
}
250302
}
251303
}
304+
305+
func isPointer(pass *analysis.Pass, exp ast.Node) bool {
306+
switch n := exp.(type) {
307+
case *ast.SelectorExpr:
308+
sel, ok := pass.TypesInfo.Selections[n]
309+
return ok && sel.Indirect()
310+
}
311+
312+
return false
313+
}

pkg/analyzer/analyzer_test.go

+19-3
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,28 @@ import (
1010
"github.com/Crocmagnon/fatcontext/pkg/analyzer"
1111
)
1212

13-
func TestAll(t *testing.T) {
13+
func TestAnalyzer(t *testing.T) {
1414
wd, err := os.Getwd()
1515
if err != nil {
1616
t.Fatalf("Failed to get wd: %s", err)
1717
}
18-
testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata")
18+
testdata := filepath.Join(wd, "testdata")
1919

20-
analysistest.Run(t, testdata, analyzer.Analyzer, "./...")
20+
t.Run("no func decl", func(t *testing.T) {
21+
an := analyzer.NewAnalyzer()
22+
analysistest.Run(t, testdata, an, "./common")
23+
analysistest.Run(t, testdata, an, "./no_structpointer")
24+
})
25+
26+
t.Run("func decl", func(t *testing.T) {
27+
an := analyzer.NewAnalyzer()
28+
29+
err := an.Flags.Set(analyzer.FlagCheckStructPointers, "true")
30+
if err != nil {
31+
t.Fatal(err)
32+
}
33+
34+
analysistest.Run(t, testdata, an, "./common")
35+
analysistest.Run(t, testdata, an, "./structpointer")
36+
})
2137
}

testdata/src/example.go pkg/analyzer/testdata/common/example.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package src
1+
package common
22

33
import (
44
"context"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package common
2+
3+
import (
4+
"context"
5+
)
6+
7+
type Container struct {
8+
Ctx context.Context
9+
}
10+
11+
func something() func(*Container) {
12+
return func(r *Container) {
13+
ctx := r.Ctx
14+
ctx = context.WithValue(ctx, "key", "val")
15+
r.Ctx = ctx
16+
}
17+
}
18+
19+
func blah(r *Container) {
20+
ctx := r.Ctx
21+
ctx = context.WithValue(ctx, "key", "val")
22+
r.Ctx = ctx
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package common
2+
3+
import (
4+
"context"
5+
)
6+
7+
type Container struct {
8+
Ctx context.Context
9+
}
10+
11+
func something() func(*Container) {
12+
return func(r *Container) {
13+
ctx := r.Ctx
14+
ctx = context.WithValue(ctx, "key", "val")
15+
r.Ctx = ctx // want "potential nested context in struct pointer"
16+
}
17+
}
18+
19+
func blah(r *Container) {
20+
ctx := r.Ctx
21+
ctx = context.WithValue(ctx, "key", "val")
22+
r.Ctx = ctx // want "potential nested context in struct pointer"
23+
}

0 commit comments

Comments
 (0)