@@ -3,6 +3,7 @@ package analyzer
3
3
import (
4
4
"bytes"
5
5
"errors"
6
+ "flag"
6
7
"fmt"
7
8
"go/ast"
8
9
"go/printer"
@@ -14,22 +15,45 @@ import (
14
15
"golang.org/x/tools/go/ast/inspector"
15
16
)
16
17
17
- var Analyzer = & analysis.Analyzer {
18
- Name : "fatcontext" ,
19
- Doc : "detects nested contexts in loops and function literals" ,
20
- Run : run ,
21
- Requires : []* analysis.Analyzer {inspect .Analyzer },
18
+ const FlagCheckStructPointers = "check-struct-pointers"
19
+
20
+ func NewAnalyzer () * analysis.Analyzer {
21
+ r := & runner {}
22
+
23
+ flags := flag .NewFlagSet ("fatcontext" , flag .ExitOnError )
24
+ flags .BoolVar (& r .DetectInStructPointers , FlagCheckStructPointers , false ,
25
+ "set to true to detect potential fat contexts in struct pointers" )
26
+
27
+ return & analysis.Analyzer {
28
+ Name : "fatcontext" ,
29
+ Doc : "detects nested contexts in loops and function literals" ,
30
+ Run : r .run ,
31
+ Flags : * flags ,
32
+ Requires : []* analysis.Analyzer {inspect .Analyzer },
33
+ }
22
34
}
23
35
24
36
var errUnknown = errors .New ("unknown node type" )
25
37
26
- func run (pass * analysis.Pass ) (interface {}, error ) {
38
+ const (
39
+ categoryInLoop = "nested context in loop"
40
+ categoryInFuncLit = "nested context in function literal"
41
+ categoryInStructPointer = "potential nested context in struct pointer"
42
+ categoryUnsupported = "unsupported nested context type"
43
+ )
44
+
45
+ type runner struct {
46
+ DetectInStructPointers bool
47
+ }
48
+
49
+ func (r * runner ) run (pass * analysis.Pass ) (interface {}, error ) {
27
50
inspctr := pass .ResultOf [inspect .Analyzer ].(* inspector.Inspector )
28
51
29
52
nodeFilter := []ast.Node {
30
53
(* ast .ForStmt )(nil ),
31
54
(* ast .RangeStmt )(nil ),
32
55
(* ast .FuncLit )(nil ),
56
+ (* ast .FuncDecl )(nil ),
33
57
}
34
58
35
59
inspctr .Preorder (nodeFilter , func (node ast.Node ) {
@@ -43,63 +67,87 @@ func run(pass *analysis.Pass) (interface{}, error) {
43
67
return
44
68
}
45
69
46
- suggestedStmt := ast.AssignStmt {
47
- Lhs : assignStmt .Lhs ,
48
- TokPos : assignStmt .TokPos ,
49
- Tok : token .DEFINE ,
50
- Rhs : assignStmt .Rhs ,
51
- }
52
- suggested , err := render (pass .Fset , & suggestedStmt )
53
-
54
- var fixes []analysis.SuggestedFix
55
- if err == nil {
56
- fixes = append (fixes , analysis.SuggestedFix {
57
- Message : "replace `=` with `:=`" ,
58
- TextEdits : []analysis.TextEdit {
59
- {
60
- Pos : assignStmt .Pos (),
61
- End : assignStmt .End (),
62
- NewText : suggested ,
63
- },
64
- },
65
- })
70
+ category := getCategory (pass , node , assignStmt )
71
+
72
+ if r .shouldIgnoreReport (category ) {
73
+ return
66
74
}
67
75
76
+ fixes := r .getSuggestedFixes (pass , assignStmt , category )
77
+
68
78
pass .Report (analysis.Diagnostic {
69
79
Pos : assignStmt .Pos (),
70
- Message : getReportMessage ( node ) ,
80
+ Message : category ,
71
81
SuggestedFixes : fixes ,
72
82
})
73
83
})
74
84
75
85
return nil , nil
76
86
}
77
87
78
- func getReportMessage (node ast.Node ) string {
88
+ func (r * runner ) shouldIgnoreReport (category string ) bool {
89
+ return category == categoryInStructPointer && ! r .DetectInStructPointers
90
+ }
91
+
92
+ func (r * runner ) getSuggestedFixes (pass * analysis.Pass , assignStmt * ast.AssignStmt , category string ) []analysis.SuggestedFix {
93
+ switch category {
94
+ case categoryInStructPointer , categoryUnsupported :
95
+ return nil
96
+ }
97
+
98
+ suggestedStmt := ast.AssignStmt {
99
+ Lhs : assignStmt .Lhs ,
100
+ TokPos : assignStmt .TokPos ,
101
+ Tok : token .DEFINE ,
102
+ Rhs : assignStmt .Rhs ,
103
+ }
104
+ suggested , err := render (pass .Fset , & suggestedStmt )
105
+
106
+ var fixes []analysis.SuggestedFix
107
+ if err == nil {
108
+ fixes = append (fixes , analysis.SuggestedFix {
109
+ Message : "replace `=` with `:=`" ,
110
+ TextEdits : []analysis.TextEdit {
111
+ {
112
+ Pos : assignStmt .Pos (),
113
+ End : assignStmt .End (),
114
+ NewText : suggested ,
115
+ },
116
+ },
117
+ })
118
+ }
119
+
120
+ return fixes
121
+ }
122
+
123
+ func getCategory (pass * analysis.Pass , node ast.Node , assignStmt * ast.AssignStmt ) string {
79
124
switch node .(type ) {
80
125
case * ast.ForStmt , * ast.RangeStmt :
81
- return "nested context in loop"
82
- case * ast.FuncLit :
83
- return "nested context in function literal"
84
- default :
85
- return "unsupported nested context type"
126
+ return categoryInLoop
86
127
}
87
- }
88
128
89
- func getBody (node ast.Node ) (* ast.BlockStmt , error ) {
90
- forStmt , ok := node .(* ast.ForStmt )
91
- if ok {
92
- return forStmt .Body , nil
129
+ if isPointer (pass , assignStmt .Lhs [0 ]) {
130
+ return categoryInStructPointer
93
131
}
94
132
95
- rangeStmt , ok := node .(* ast.RangeStmt )
96
- if ok {
97
- return rangeStmt .Body , nil
133
+ switch node .(type ) {
134
+ case * ast.FuncLit , * ast.FuncDecl :
135
+ return categoryInFuncLit
136
+ default :
137
+ return categoryUnsupported
98
138
}
139
+ }
99
140
100
- funcLit , ok := node .(* ast.FuncLit )
101
- if ok {
102
- return funcLit .Body , nil
141
+ func getBody (node ast.Node ) (* ast.BlockStmt , error ) {
142
+ switch typedNode := node .(type ) {
143
+ case * ast.ForStmt :
144
+ return typedNode .Body , nil
145
+ case * ast.RangeStmt :
146
+ return typedNode .Body , nil
147
+ case * ast.FuncLit :
148
+ return typedNode .Body , nil
149
+ case * ast.FuncDecl :
150
+ return typedNode .Body , nil
103
151
}
104
152
105
153
return nil , errUnknown
@@ -174,6 +222,10 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
174
222
continue
175
223
}
176
224
225
+ if isPointer (pass , assignStmt .Lhs [0 ]) {
226
+ return assignStmt
227
+ }
228
+
177
229
// allow assignment to non-pointer children of values defined within the loop
178
230
if isWithinLoop (assignStmt .Lhs [0 ], node , pass ) {
179
231
continue
@@ -249,3 +301,13 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
249
301
}
250
302
}
251
303
}
304
+
305
+ func isPointer (pass * analysis.Pass , exp ast.Node ) bool {
306
+ switch n := exp .(type ) {
307
+ case * ast.SelectorExpr :
308
+ sel , ok := pass .TypesInfo .Selections [n ]
309
+ return ok && sel .Indirect ()
310
+ }
311
+
312
+ return false
313
+ }
0 commit comments