Skip to content

Commit caaf586

Browse files
authored
fix the SQL Injection (#330)
1 parent 0e148b8 commit caaf586

File tree

4 files changed

+89
-46
lines changed

4 files changed

+89
-46
lines changed

global/global.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,22 @@ func CommandLine(cmd string, args []string) string {
311311
}
312312
return result
313313
}
314+
315+
// EscapeQuote escape the string the single quote, double quote, and backtick
316+
func EscapeQuote(str string) string {
317+
type Escape struct {
318+
From string
319+
To string
320+
}
321+
escape := []Escape{
322+
{From: "`", To: ""}, // remove the backtick
323+
{From: `\`, To: `\\`},
324+
{From: `'`, To: `\'`},
325+
{From: `"`, To: `\"`},
326+
}
327+
328+
for _, e := range escape {
329+
str = strings.ReplaceAll(str, e.From, e.To)
330+
}
331+
return str
332+
}

global/global_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,12 @@ func TestCommandLine(t *testing.T) {
438438
s = CommandLine("kubectl", []string{"get", "pod", "--all-namespaces", "-o", "json"})
439439
assert.Equal(t, "kubectl get pod --all-namespaces -o json", s)
440440
}
441+
442+
func TestEscape(t *testing.T) {
443+
assert.Equal(t, "test", EscapeQuote("test"))
444+
assert.Equal(t, "test", EscapeQuote("`test`"))
445+
assert.Equal(t, `\'test\'`, EscapeQuote("'test'"))
446+
assert.Equal(t, `\"test\"`, EscapeQuote(`"test"`))
447+
assert.Equal(t, `\\test\\`, EscapeQuote(`\test\`))
448+
449+
}

probe/client/mysql/mysql.go

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -106,50 +106,65 @@ func (r *MySQL) Probe() (bool, string) {
106106

107107
// Check if we need to query specific data
108108
if len(r.Data) > 0 {
109-
for k, v := range r.Data {
110-
log.Debugf("[%s / %s / %s] - Verifying Data - [%s] : [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, k, v)
111-
sql, err := r.getSQL(k)
112-
if err != nil {
113-
return false, err.Error()
114-
}
115-
log.Debugf("[%s / %s / %s] - SQL - [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, sql)
116-
rows, err := db.Query(sql)
117-
if err != nil {
118-
return false, err.Error()
119-
}
120-
if !rows.Next() {
121-
rows.Close()
122-
return false, fmt.Sprintf("No data found for [%s]", k)
123-
}
124-
//check the value is equal to the value in data
125-
var value string
126-
if err := rows.Scan(&value); err != nil {
127-
rows.Close()
128-
return false, err.Error()
129-
}
130-
if value != v {
131-
rows.Close()
132-
return false, fmt.Sprintf("Value not match for [%s] expected [%s] got [%s] ", k, v, value)
133-
}
134-
rows.Close()
135-
log.Debugf("[%s / %s / %s] - Data Verified Successfully! - [%s] : [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, k, v)
136-
}
137-
} else {
138-
err = db.Ping()
139-
if err != nil {
109+
if err := r.ProbeWithDataVerification(db); err != nil {
140110
return false, err.Error()
141111
}
142-
row, err := db.Query("show status like \"uptime\"") // run a SQL to test
143-
if err != nil {
112+
} else {
113+
if err := r.ProbeWithPing(db); err != nil {
144114
return false, err.Error()
145115
}
146-
defer row.Close()
147116
}
148117

149118
return true, "Check MySQL Server Successfully!"
150119

151120
}
152121

122+
// ProbeWithPing do the health check with ping
123+
func (r *MySQL) ProbeWithPing(db *sql.DB) error {
124+
if err := db.Ping(); err != nil {
125+
return err
126+
}
127+
row, err := db.Query("show status like \"uptime\"") // run a SQL to test
128+
if err != nil {
129+
return err
130+
}
131+
defer row.Close()
132+
return nil
133+
}
134+
135+
// ProbeWithDataVerification do the health check with data verification
136+
func (r *MySQL) ProbeWithDataVerification(db *sql.DB) error {
137+
for k, v := range r.Data {
138+
log.Debugf("[%s / %s / %s] - Verifying Data - [%s] : [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, k, v)
139+
sql, err := r.getSQL(k)
140+
if err != nil {
141+
return err
142+
}
143+
log.Debugf("[%s / %s / %s] - SQL - [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, sql)
144+
rows, err := db.Query(sql)
145+
if err != nil {
146+
return err
147+
}
148+
if !rows.Next() {
149+
rows.Close()
150+
return fmt.Errorf("No data found for [%s]", k)
151+
}
152+
//check the value is equal to the value in data
153+
var value string
154+
if err := rows.Scan(&value); err != nil {
155+
rows.Close()
156+
return err
157+
}
158+
if value != v {
159+
rows.Close()
160+
return fmt.Errorf("Value not match for [%s] expected [%s] got [%s] ", k, v, value)
161+
}
162+
rows.Close()
163+
log.Debugf("[%s / %s / %s] - Data Verified Successfully! - [%s] : [%s]", r.ProbeKind, r.ProbeName, r.ProbeTag, k, v)
164+
}
165+
return nil
166+
}
167+
153168
// getSQL get the SQL statement
154169
// input: database:table:column:key:value
155170
// output: SELECT column FROM database.table WHERE key = value
@@ -161,16 +176,16 @@ func (r *MySQL) getSQL(str string) (string, error) {
161176
if len(fields) != 5 {
162177
return "", fmt.Errorf("Invalid SQL data - [%s]. (syntax: database:table:field:key:value)", str)
163178
}
164-
db := fields[0]
165-
table := fields[1]
166-
field := fields[2]
167-
key := fields[3]
168-
value := fields[4]
179+
db := global.EscapeQuote(fields[0])
180+
table := global.EscapeQuote(fields[1])
181+
field := global.EscapeQuote(fields[2])
182+
key := global.EscapeQuote(fields[3])
183+
value := global.EscapeQuote(fields[4])
169184
//check value is int or not
170185
if _, err := strconv.Atoi(value); err != nil {
171186
return "", fmt.Errorf("Invalid SQL data - [%s], the value must be int", str)
172187
}
173188

174-
sql := fmt.Sprintf("SELECT %s FROM %s.%s WHERE %s = %s", field, db, table, key, value)
189+
sql := fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE `%s` = %s", field, db, table, key, value)
175190
return sql, nil
176191
}

probe/client/postgres/postgres.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,16 @@ func (r *PostgreSQL) getSQL(str string) (string, string, error) {
190190
if len(fields) != 5 {
191191
return "", "", fmt.Errorf("Invalid SQL data - [%s]. (syntax: database:table:field:key:value)", str)
192192
}
193-
db := fields[0]
194-
table := fields[1]
195-
field := fields[2]
196-
key := fields[3]
197-
value := fields[4]
193+
db := global.EscapeQuote(fields[0])
194+
table := global.EscapeQuote(fields[1])
195+
field := global.EscapeQuote(fields[2])
196+
key := global.EscapeQuote(fields[3])
197+
value := global.EscapeQuote(fields[4])
198198
//check value is int or not
199199
if _, err := strconv.Atoi(value); err != nil {
200200
return "", "", fmt.Errorf("Invalid SQL data - [%s], the value must be int", str)
201201
}
202202

203-
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s = %s", field, table, key, value)
203+
sql := fmt.Sprintf(`SELECT "%s" FROM "%s" WHERE "%s" = %s`, field, table, key, value)
204204
return db, sql, nil
205205
}

0 commit comments

Comments
 (0)