forked from src-d/go-mysql-server
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmysql_test.go
148 lines (123 loc) · 2.85 KB
/
mysql_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package testmysql
import (
	"database/sql"
	"fmt"
	"reflect"
	"testing"

	_ "github.com/go-sql-driver/mysql"
)
// connectionString is the go-sql-driver/mysql DSN shared by every test in
// this file: user root with an empty password, TCP to 127.0.0.1:3306,
// database mydb.
const (
	connectionString = "root:@tcp(127.0.0.1:3306)/mydb"
)
// TestMySQL runs an ordered SELECT against mytable and checks the returned
// (name, email) pairs against a fixed expectation.
//
// It needs a server reachable at connectionString whose mydb.mytable is
// already populated; nothing in this test creates or seeds the table.
func TestMySQL(t *testing.T) {
	db, err := sql.Open("mysql", connectionString)
	if err != nil {
		t.Fatalf("can't connect to mysql: %s", err)
	}
	defer db.Close()
	// sql.Open may only validate its arguments without dialing; Ping forces a
	// real connection so an unreachable server is reported as such.
	if err := db.Ping(); err != nil {
		t.Fatalf("can't connect to mysql: %s", err)
	}

	rs, err := db.Query("SELECT name, email FROM mytable ORDER BY name, email")
	if err != nil {
		t.Fatalf("unable to get rows: %s", err)
	}
	defer rs.Close()

	var rows [][2]string
	for rs.Next() {
		var row [2]string
		if err := rs.Scan(&row[0], &row[1]); err != nil {
			t.Errorf("got error scanning row: %s", err)
		}
		// Appended even after a scan error so later rows stay aligned with
		// their expected counterparts.
		rows = append(rows, row)
	}
	if err := rs.Err(); err != nil {
		t.Errorf("got unexpected error: %s", err)
	}

	// NOTE(review): these e-mail literals look like they were rewritten by a
	// web e-mail obfuscator ("[email protected]"), so all four are identical.
	// Restore the real fixture addresses before relying on this comparison.
	expected := [][2]string{
		{"Evil Bob", "[email protected]"},
		{"Jane Doe", "[email protected]"},
		{"John Doe", "[email protected]"},
		{"John Doe", "[email protected]"},
	}

	if len(expected) != len(rows) {
		t.Errorf("got %d rows, expecting %d", len(rows), len(expected))
	}

	// Compare only the overlapping prefix. Ranging over rows alone would
	// index past the end of expected, and panic, when extra rows come back.
	n := len(rows)
	if len(expected) < n {
		n = len(expected)
	}
	for i := 0; i < n; i++ {
		// [2]string values are comparable directly.
		if rows[i] != expected[i] {
			t.Errorf(
				"incorrect row %d, got: {%s, %s}, expected: {%s, %s}",
				i,
				rows[i][0], rows[i][1],
				expected[i][0], expected[i][1],
			)
		}
	}
}
// TestGrafana replays queries of the kind Grafana issues against a MySQL
// data source. Each query runs as its own subtest, and its stringified
// result rows are compared with the expected values.
//
// It needs a server reachable at connectionString with mydb.mytable present.
func TestGrafana(t *testing.T) {
	db, err := sql.Open("mysql", connectionString)
	if err != nil {
		t.Fatalf("can't connect to mysql: %s", err)
	}
	defer db.Close()
	// sql.Open may only validate its arguments; Ping forces a real connection.
	if err := db.Ping(); err != nil {
		t.Fatalf("can't connect to mysql: %s", err)
	}

	tests := []struct {
		query    string
		expected [][]string
	}{
		{
			`SELECT 1`,
			[][]string{{"1"}},
		},
		{
			`select @@version_comment limit 1`,
			[][]string{{""}},
		},
		{
			`describe table mytable`,
			[][]string{
				{"name", "TEXT"},
				{"email", "TEXT"},
				{"phone_numbers", "JSON"},
				{"created_at", "TIMESTAMP"},
			},
		},
		{
			`select count(*) from mytable where created_at ` +
				`between '2000-01-01T00:00:00Z' and '2999-01-01T00:00:00Z'`,
			[][]string{{"4"}},
		},
	}

	for _, c := range tests {
		// Subtests run synchronously here, so capturing c is safe. A failure
		// in one query no longer hides the results of the remaining ones.
		t.Run(c.query, func(t *testing.T) {
			rs, err := db.Query(c.query)
			if err != nil {
				t.Fatalf("unable to execute query: %s", err)
			}
			defer rs.Close()

			result := getResult(t, rs)
			if !reflect.DeepEqual(result, c.expected) {
				t.Fatalf("rows do not match, expected: %v, got: %v", c.expected, result)
			}
		})
	}
}
// getResult drains rs and returns every row converted to strings by
// getStringSlice. It fails the calling test (via t.Fatalf) on any column,
// scan or iteration error, and closes rs before returning in every case.
func getResult(t *testing.T, rs *sql.Rows) [][]string {
	t.Helper()
	// Release the rows, and the connection they hold, even when bailing out
	// through t.Fatalf. Rows.Close is safe to call more than once.
	defer rs.Close()

	columns, err := rs.Columns()
	if err != nil {
		t.Fatalf("unable to get columns: %s", err)
	}

	var result [][]string
	// p holds one pointer per column and is reused for every row.
	p := make([]interface{}, len(columns))
	for rs.Next() {
		// A fresh row each iteration: it is handed to getStringSlice and must
		// not be overwritten by the next Scan.
		row := make([]interface{}, len(columns))
		for i := range row {
			p[i] = &row[i]
		}
		if err := rs.Scan(p...); err != nil {
			t.Fatalf("could not retrieve row: %s", err)
		}
		result = append(result, getStringSlice(row))
	}
	// Next returns false both at the end of the result set and on failure;
	// only Err distinguishes the two.
	if err := rs.Err(); err != nil {
		t.Fatalf("error iterating rows: %s", err)
	}
	return result
}
// getStringSlice renders one scanned row as strings so it can be compared
// against literal expectations.
//
// Conversion rules:
//   - a nil value (SQL NULL) becomes "NULL";
//   - a []byte is converted verbatim;
//   - any other driver value (e.g. int64, float64, bool, string, time.Time)
//     is formatted with fmt.Sprint instead of panicking on a failed type
//     assertion.
func getStringSlice(row []interface{}) []string {
	rowStrings := make([]string, len(row))
	for i, r := range row {
		switch v := r.(type) {
		case nil:
			rowStrings[i] = "NULL"
		case []byte:
			rowStrings[i] = string(v)
		default:
			rowStrings[i] = fmt.Sprint(v)
		}
	}
	return rowStrings
}