Skip to content

Commit ddad04c

Browse files
committed
allow updating expected vars/consts inside functions
1 parent 97735af commit ddad04c

File tree

2 files changed

+102
-20
lines changed

2 files changed

+102
-20
lines changed

assert/assert_ext_test.go

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package assert_test
22

33
import (
4+
"go/ast"
45
"go/parser"
56
"go/token"
67
"io/ioutil"
@@ -56,6 +57,48 @@ expected value
5657
expected := "const expectedTwo = `this is the new\nexpected value\n`"
5758
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
5859
})
60+
61+
t.Run("var inside function is updated when -update=true", func(t *testing.T) {
62+
patchUpdate(t)
63+
t.Cleanup(func() {
64+
resetVariable(t, "expectedInsideFunc", "")
65+
})
66+
67+
actual := `this is the new
68+
expected value
69+
for var inside function
70+
`
71+
expectedInsideFunc := ``
72+
73+
assert.Equal(t, actual, expectedInsideFunc)
74+
75+
raw, err := ioutil.ReadFile(fileName(t))
76+
assert.NilError(t, err)
77+
78+
expected := "expectedInsideFunc := `this is the new\nexpected value\nfor var inside function\n`"
79+
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
80+
})
81+
82+
t.Run("const inside function is updated when -update=true", func(t *testing.T) {
83+
patchUpdate(t)
84+
t.Cleanup(func() {
85+
resetVariable(t, "expectedConstInsideFunc", "")
86+
})
87+
88+
actual := `this is the new
89+
expected value
90+
for const inside function
91+
`
92+
const expectedConstInsideFunc = ``
93+
94+
assert.Equal(t, actual, expectedConstInsideFunc)
95+
96+
raw, err := ioutil.ReadFile(fileName(t))
97+
assert.NilError(t, err)
98+
99+
expected := "const expectedConstInsideFunc = `this is the new\nexpected value\nfor const inside function\n`"
100+
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
101+
})
59102
}
60103

61104
// expectedOne is updated by running the tests with -update
@@ -87,7 +130,33 @@ func resetVariable(t *testing.T, varName string, value string) {
87130
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
88131
assert.NilError(t, err)
89132

90-
err = source.UpdateVariable(filename, fileset, astFile, varName, value)
133+
var ident *ast.Ident
134+
ast.Inspect(astFile, func(n ast.Node) bool {
135+
switch v := n.(type) {
136+
case *ast.AssignStmt:
137+
if len(v.Lhs) == 1 {
138+
if id, ok := v.Lhs[0].(*ast.Ident); ok {
139+
if id.Name == varName {
140+
ident = id
141+
return false
142+
}
143+
}
144+
}
145+
146+
case *ast.ValueSpec:
147+
for _, id := range v.Names {
148+
if id.Name == varName {
149+
ident = id
150+
return false
151+
}
152+
}
153+
}
154+
155+
return true
156+
})
157+
assert.Assert(t, ident != nil, "failed to get ident for %s", varName)
158+
159+
err = source.UpdateVariable(filename, fileset, astFile, ident, value)
91160
assert.NilError(t, err, "failed to reset file")
92161
}
93162

internal/source/update.go

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
5454
return ErrNotFound
5555
}
5656

57-
argIndex, varName := getVarNameForExpectedValueArg(expr)
58-
if argIndex < 0 || varName == "" {
57+
argIndex, ident := getVarNameForExpectedValueArg(expr)
58+
if argIndex < 0 || ident == nil {
5959
debug("no arguments started with the word 'expected': %v",
6060
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
6161
return ErrNotFound
@@ -71,7 +71,7 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
7171
debug("value must be type string, got %T", value)
7272
return ErrNotFound
7373
}
74-
return UpdateVariable(filename, fileset, astFile, varName, strValue)
74+
return UpdateVariable(filename, fileset, astFile, ident, strValue)
7575
}
7676

7777
// UpdateVariable writes to filename the contents of astFile with the value of
@@ -80,10 +80,10 @@ func UpdateVariable(
8080
filename string,
8181
fileset *token.FileSet,
8282
astFile *ast.File,
83-
varName string,
83+
ident *ast.Ident,
8484
value string,
8585
) error {
86-
obj := astFile.Scope.Objects[varName]
86+
obj := ident.Obj
8787
if obj == nil {
8888
return ErrNotFound
8989
}
@@ -92,20 +92,33 @@ func UpdateVariable(
9292
return ErrNotFound
9393
}
9494

95-
spec, ok := obj.Decl.(*ast.ValueSpec)
96-
if !ok {
95+
switch decl := obj.Decl.(type) {
96+
case *ast.ValueSpec:
97+
if len(decl.Names) != 1 {
98+
debug("more than one name in ast.ValueSpec")
99+
return ErrNotFound
100+
}
101+
102+
decl.Values[0] = &ast.BasicLit{
103+
Kind: token.STRING,
104+
Value: "`" + value + "`",
105+
}
106+
107+
case *ast.AssignStmt:
108+
if len(decl.Lhs) != 1 {
109+
debug("more than one name in ast.AssignStmt")
110+
return ErrNotFound
111+
}
112+
113+
decl.Rhs[0] = &ast.BasicLit{
114+
Kind: token.STRING,
115+
Value: "`" + value + "`",
116+
}
117+
118+
default:
97119
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
98120
return ErrNotFound
99121
}
100-
if len(spec.Names) != 1 {
101-
debug("more than one name in ast.ValueSpec")
102-
return ErrNotFound
103-
}
104-
105-
spec.Values[0] = &ast.BasicLit{
106-
Kind: token.STRING,
107-
Value: "`" + value + "`",
108-
}
109122

110123
var buf bytes.Buffer
111124
if err := format.Node(&buf, fileset, astFile); err != nil {
@@ -125,14 +138,14 @@ func UpdateVariable(
125138
return nil
126139
}
127140

128-
func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) {
141+
func getVarNameForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
129142
for i := 1; i < 3; i++ {
130143
switch e := expr[i].(type) {
131144
case *ast.Ident:
132145
if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
133-
return i, e.Name
146+
return i, e
134147
}
135148
}
136149
}
137-
return -1, ""
150+
return -1, nil
138151
}

0 commit comments

Comments
 (0)