From 4d8aab91ea4573c5a947c3394aed13167f933aa7 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Fri, 22 Sep 2023 17:26:08 +0800 Subject: [PATCH 1/3] add String() for FieldValue Signed-off-by: lance6716 --- client/client_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++ mysql/field.go | 20 +++++++++++++ 2 files changed, 88 insertions(+) diff --git a/client/client_test.go b/client/client_test.go index 7fddd9dad..045064a9f 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -423,3 +423,71 @@ func (s *clientTestSuite) TestStmt_Trans() { str, _ = r.GetString(0, 0) require.Equal(s.T(), `abc`, str) } + +func (s *clientTestSuite) TestFieldValueString() { + _, err := s.c.Execute( + ` +CREATE TABLE field_value_test ( + c_int int, + c_bit bit(8), + c_tinyint_u tinyint unsigned, + c_decimal decimal(10, 5), + c_float float(10), + c_date date, + c_datetime datetime(3), + c_timestamp timestamp(4), + c_time time(5), + c_year year, + c_char char(10), + c_varchar varchar(10), + c_binary binary(10), + c_varbinary varbinary(10), + c_blob blob, + c_enum enum('a', 'b', 'c'), + c_set set('a', 'b', 'c'), + c_json json, + c_null int +)`) + require.NoError(s.T(), err) + s.T().Cleanup(func() { + s.c.Execute( + `DROP TABLE field_value_test`) + }) + + _, err = s.c.Execute(` +INSERT INTO field_value_test VALUES ( + 1, 2, 3, 4.5, 6.7, + '2019-01-01', '2019-01-01 01:01:01.123', '2019-01-01 01:01:01.1234', '01:01:01.12345', 2019, + 'char', 'varchar', 'binary', 'varbinary', 'blob', 'a', 'a,b', '{"a": 1}', + NULL +)`) + require.NoError(s.T(), err) + + result, err := s.c.Execute(`SELECT * FROM field_value_test`) + require.NoError(s.T(), err) + require.Len(s.T(), result.Values, 1) + expected := []string{ + `1`, "'\x02'", `3`, `'4.50000'`, `6.7`, + `'2019-01-01'`, `'2019-01-01 01:01:01.123'`, `'2019-01-01 01:01:01.1234'`, `'01:01:01.12345'`, `2019`, + `'char'`, `'varchar'`, "'binary\x00\x00\x00\x00'", `'varbinary'`, `'blob'`, `'a'`, `'a,b'`, `'{"a": 1}'`, + `NULL`, + } + for i, v := range result.Values[0] { + require.Equal(s.T(), expected[i], v.String()) + } + + // test can directly use to build a SQL, through it's not safe in most cases + sql := fmt.Sprintf("INSERT INTO field_value_test VALUES (%s)", strings.Join(expected, ",")) + _, err = s.c.Execute(sql) + require.NoError(s.T(), err) + result, err = s.c.Execute(`SELECT * FROM field_value_test`) + require.NoError(s.T(), err) + // check again, everything is same + require.Len(s.T(), result.Values, 2) + for i, v := range result.Values[0] { + require.Equal(s.T(), expected[i], v.String()) + } + for i, v := range result.Values[1] { + require.Equal(s.T(), expected[i], v.String()) + } +} diff --git a/mysql/field.go b/mysql/field.go index 9504e931b..2cb57467a 100644 --- a/mysql/field.go +++ b/mysql/field.go @@ -2,6 +2,8 @@ package mysql import ( "encoding/binary" + "fmt" + "strconv" "github.com/go-mysql-org/go-mysql/utils" ) @@ -210,3 +212,21 @@ func (fv *FieldValue) Value() interface{} { return nil } } + +// String returns a MySQL literal string that equals the value. +func (fv *FieldValue) String() string { + switch fv.Type { + case FieldValueTypeNull: + return "NULL" + case FieldValueTypeUnsigned: + return strconv.FormatUint(fv.AsUint64(), 10) + case FieldValueTypeSigned: + return strconv.FormatInt(fv.AsInt64(), 10) + case FieldValueTypeFloat: + return strconv.FormatFloat(fv.AsFloat64(), 'f', -1, 64) + case FieldValueTypeString: + return "'" + string(fv.AsString()) + "'" + default: + return fmt.Sprintf("unknown type %d of FieldValue", fv.Type) + } +} From bc3808ad09580c0e895effa5fa9fbac532c941d1 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Fri, 22 Sep 2023 17:28:22 +0800 Subject: [PATCH 2/3] fix lint Signed-off-by: lance6716 --- client/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/client_test.go b/client/client_test.go index 045064a9f..2cec7e4a9 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -450,7 +450,7 @@ CREATE TABLE field_value_test ( )`) require.NoError(s.T(), err) s.T().Cleanup(func() { - s.c.Execute( + _, _ = s.c.Execute( `DROP TABLE field_value_test`) }) From 77e7e6458b47755c47dcfa8a55b9f962e72308bb Mon Sep 17 00:00:00 2001 From: lance6716 Date: Fri, 22 Sep 2023 17:39:50 +0800 Subject: [PATCH 3/3] escape Signed-off-by: lance6716 --- client/client_test.go | 6 +++--- mysql/field.go | 13 ++++++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 2cec7e4a9..c47c795ef 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -458,7 +458,7 @@ CREATE TABLE field_value_test ( INSERT INTO field_value_test VALUES ( 1, 2, 3, 4.5, 6.7, '2019-01-01', '2019-01-01 01:01:01.123', '2019-01-01 01:01:01.1234', '01:01:01.12345', 2019, - 'char', 'varchar', 'binary', 'varbinary', 'blob', 'a', 'a,b', '{"a": 1}', + 'cha\'r', 'varchar', 'binary', 'varbinary', 'blob', 'a', 'a,b', '{"a": 1}', NULL )`) require.NoError(s.T(), err) @@ -469,14 +469,14 @@ INSERT INTO field_value_test VALUES ( expected := []string{ `1`, "'\x02'", `3`, `'4.50000'`, `6.7`, `'2019-01-01'`, `'2019-01-01 01:01:01.123'`, `'2019-01-01 01:01:01.1234'`, `'01:01:01.12345'`, `2019`, - `'char'`, `'varchar'`, "'binary\x00\x00\x00\x00'", `'varbinary'`, `'blob'`, `'a'`, `'a,b'`, `'{"a": 1}'`, + `'cha\'r'`, `'varchar'`, "'binary\x00\x00\x00\x00'", `'varbinary'`, `'blob'`, `'a'`, `'a,b'`, `'{"a": 1}'`, `NULL`, } for i, v := range result.Values[0] { require.Equal(s.T(), expected[i], v.String()) } - // test can directly use to build a SQL, through it's not safe in most cases + // test can directly use to build a SQL, though it's not safe in most cases sql := fmt.Sprintf("INSERT INTO field_value_test VALUES (%s)", strings.Join(expected, ",")) _, err = s.c.Execute(sql) require.NoError(s.T(), err) diff --git a/mysql/field.go b/mysql/field.go index 2cb57467a..da83dca2c 100644 --- a/mysql/field.go +++ b/mysql/field.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "strconv" + "strings" "github.com/go-mysql-org/go-mysql/utils" ) @@ -225,7 +226,17 @@ func (fv *FieldValue) String() string { case FieldValueTypeFloat: return strconv.FormatFloat(fv.AsFloat64(), 'f', -1, 64) case FieldValueTypeString: - return "'" + string(fv.AsString()) + "'" + b := strings.Builder{} + b.Grow(len(fv.str) + 2) + b.WriteByte('\'') + for i := range fv.str { + if fv.str[i] == '\'' { + b.WriteByte('\\') + } + b.WriteByte(fv.str[i]) + } + b.WriteByte('\'') + return b.String() default: return fmt.Sprintf("unknown type %d of FieldValue", fv.Type) }