diff --git a/connection.go b/connection.go index cc53793..da303fa 100644 --- a/connection.go +++ b/connection.go @@ -11,15 +11,17 @@ import ( // hiveOptions for opened Hive sessions. type hiveOptions struct { - PollIntervalSeconds int64 - BatchSize int64 + PollIntervalSeconds int64 + BatchSize int64 + ColumnsWithoutTableName bool // column names not contains table name } type hiveConnection struct { - thrift *hiveserver2.TCLIServiceClient - session *hiveserver2.TSessionHandle - options hiveOptions - ctx context.Context + thrift *hiveserver2.TCLIServiceClient + session *hiveserver2.TSessionHandle + options hiveOptions + ctx context.Context + paramsInterpolator *ParamsInterpolator } func (c *hiveConnection) Begin() (driver.Tx, error) { @@ -81,6 +83,13 @@ func removeLastSemicolon(s string) string { } func (c *hiveConnection) execute(ctx context.Context, query string, args []driver.NamedValue) (*hiveserver2.TExecuteStatementResp, error) { + var err error + if len(args) != 0 { + query, err = c.paramsInterpolator.InterpolateNamedValue(query, args) + if err != nil { + return nil, err + } + } executeReq := hiveserver2.NewTExecuteStatementReq() executeReq.SessionHandle = c.session executeReq.Statement = removeLastSemicolon(query) diff --git a/driver.go b/driver.go index 2d0d25c..e2cfa4e 100644 --- a/driver.go +++ b/driver.go @@ -66,12 +66,17 @@ func (d drv) Open(dsn string) (driver.Conn, error) { return nil, err } - options := hiveOptions{PollIntervalSeconds: 5, BatchSize: int64(cfg.Batch)} + options := hiveOptions{ + PollIntervalSeconds: 5, + BatchSize: int64(cfg.Batch), + ColumnsWithoutTableName: cfg.ColumnsWithoutTableName, + } conn := &hiveConnection{ - thrift: client, - session: session.SessionHandle, - options: options, - ctx: context.Background(), + thrift: client, + session: session.SessionHandle, + options: options, + ctx: context.Background(), + paramsInterpolator: NewParamsInterpolator(), } return conn, nil } diff --git a/driver_test.go b/driver_test.go index da73614..baf1f68 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2,6 +2,7 @@ package gohive import ( "database/sql" + "database/sql/driver" "fmt" "os" "reflect" @@ -163,3 +164,11 @@ func TestExec(t *testing.T) { defer db.Close() a.NoError(err) } + +func TestExecArgs(t *testing.T) { + a := assert.New(t) + db, _ := newDB("churn") + _, err := db.Exec("insert into churn.test (gender) values (?)", []driver.Value{"Female"}) + defer db.Close() + a.NoError(err) +} diff --git a/dsn.go b/dsn.go index 2c7c591..68856ac 100644 --- a/dsn.go +++ b/dsn.go @@ -9,13 +9,14 @@ import ( ) type Config struct { - User string - Passwd string - Addr string - DBName string - Auth string - Batch int - SessionCfg map[string]string + User string + Passwd string + Addr string + DBName string + Auth string + Batch int + ColumnsWithoutTableName bool // column names not contains table name + SessionCfg map[string]string } var ( @@ -25,11 +26,12 @@ var ( ) const ( - sessionConfPrefix = "session." - authConfName = "auth" - defaultAuth = "NOSASL" - batchSizeName = "batch" - defaultBatchSize = 10000 + sessionConfPrefix = "session." + authConfName = "auth" + defaultAuth = "NOSASL" + batchSizeName = "batch" + columnsWithoutTableNameName = "columns_without_table_name" + defaultBatchSize = 10000 ) // ParseDSN requires DSN names in the format [user[:password]@]addr/dbname. @@ -60,6 +62,8 @@ func ParseDSN(dsn string) (*Config, error) { auth := defaultAuth batch := defaultBatchSize + columnsWithoutTableName := false + var err error sc := make(map[string]string) if len(sub[3]) > 0 && sub[3][0] == '?' { qry, _ := url.ParseQuery(sub[3][1:]) @@ -74,6 +78,12 @@ func ParseDSN(dsn string) (*Config, error) { } batch = bch } + if v, found := qry[columnsWithoutTableNameName]; found { + columnsWithoutTableName, err = strconv.ParseBool(v[0]) + if err != nil { + return nil, err + } + } for k, v := range qry { if strings.HasPrefix(k, sessionConfPrefix) { @@ -83,13 +93,14 @@ func ParseDSN(dsn string) (*Config, error) { } return &Config{ - User: user, - Passwd: passwd, - Addr: addr, - DBName: dbname, - Auth: auth, - Batch: batch, - SessionCfg: sc, + User: user, + Passwd: passwd, + Addr: addr, + DBName: dbname, + Auth: auth, + Batch: batch, + ColumnsWithoutTableName: columnsWithoutTableName, + SessionCfg: sc, }, nil } @@ -103,6 +114,9 @@ func (cfg *Config) FormatDSN() string { if len(cfg.Auth) > 0 { dsn += fmt.Sprintf("&auth=%s", cfg.Auth) } + if cfg.ColumnsWithoutTableName { + dsn += "&columns_without_table_name=true" + } if len(cfg.SessionCfg) > 0 { for k, v := range cfg.SessionCfg { dsn += fmt.Sprintf("&%s%s=%s", sessionConfPrefix, k, v) diff --git a/dsn_test.go b/dsn_test.go index d41ca75..7518668 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -117,3 +117,15 @@ func TestFormatDSNWithoutDBName(t *testing.T) { ds2 := cfg.FormatDSN() assert.Equal(t, ds2, ds) } + +func TestFormatDSNColumnsWithoutTableNameName(t *testing.T) { + ds := "user:passwd@127.0.0.1?columns_without_table_name=true" + cfg, e := ParseDSN(ds) + assert.Nil(t, e) + assert.True(t, cfg.ColumnsWithoutTableName) + + ds2 := "user:passwd@127.0.0.1" + cfg2, e := ParseDSN(ds2) + assert.Nil(t, e) + assert.False(t, cfg2.ColumnsWithoutTableName) +} diff --git a/params_replacer.go b/params_replacer.go new file mode 100644 index 0000000..4147af1 --- /dev/null +++ b/params_replacer.go @@ -0,0 +1,201 @@ +package gohive + +import ( + "database/sql/driver" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" +) + +const ( + TimeStampLayout = "2006-01-02 15:04:05.999999999" + DateLayout = "2006-01-02" +) + +type ParamsInterpolator struct { + Local *time.Location +} + +func NewParamsInterpolator() *ParamsInterpolator { + return &ParamsInterpolator{ + Local: time.Local, + } +} + +func (p *ParamsInterpolator) InterpolateNamedValue(query string, namedArgs []driver.NamedValue) (string, error) { + args, err := namedValueToValue(namedArgs) + if err != nil { + return "", err + } + return p.Interpolate(query, args) +} + +func (p *ParamsInterpolator) Interpolate(query string, args []driver.Value) (string, error) { + if strings.Count(query, "?") != len(args) { + return "", fmt.Errorf("gohive driver: number of ? [%d] must be equal to len(args): [%d]", + strings.Count(query, "?"), len(args)) + } + + var err error + + argIdx := 0 + var buf = make([]byte, 0, len(query)+len(args)*15) + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argIdx] + argIdx++ + + buf, err = p.interpolateOne(buf, arg) + if err != nil { + return "", fmt.Errorf("gohive driver: failed to interpolate failed: %w, args[%d]: [%v]", + err, argIdx, arg) + } + + } + if argIdx != len(args) { + return "", fmt.Errorf("gohive driver: args are not all filled into SQL, argIdx: %d, total: %d", + argIdx, len(args)) + } + return string(buf), nil + +} + +func (p *ParamsInterpolator) interpolateOne(buf []byte, arg driver.Value) ([]byte, error) { + if arg == nil { + buf = append(buf, "NULL"...) + return buf, nil + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, "'true'"...) + } else { + buf = append(buf, "'false'"...) + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = appendDateTime(buf, v.In(p.Local)) + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = appendBytes(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "X'"...) + buf = appendBytes(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + default: + return nil, fmt.Errorf("gohive driver: unexpected args type: %T", arg) + } + return buf, nil +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + args := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, fmt.Errorf("gohive driver: driver does not support the use of Named Parameters") + } + args[n] = param.Value + } + return args, nil +} + +func appendBytes(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)+hex.EncodedLen(len(v))) + pos += hex.Encode(buf[pos:], v) + return buf[:pos] +} + +func appendDateTime(buf []byte, t time.Time) []byte { + buf = t.AppendFormat(buf, TimeStampLayout) + return buf +} + +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos+1] = '0' + buf[pos] = '\\' + pos += 2 + case '\n': + buf[pos+1] = 'n' + buf[pos] = '\\' + pos += 2 + case '\r': + buf[pos+1] = 'r' + buf[pos] = '\\' + pos += 2 + case '\x1a': + buf[pos+1] = 'Z' + buf[pos] = '\\' + pos += 2 + case '\'': + buf[pos+1] = '\'' + buf[pos] = '\\' + pos += 2 + case '"': + buf[pos+1] = '"' + buf[pos] = '\\' + pos += 2 + case '\\': + buf[pos+1] = '\\' + buf[pos] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} diff --git a/params_replacer_test.go b/params_replacer_test.go new file mode 100644 index 0000000..801ab38 --- /dev/null +++ b/params_replacer_test.go @@ -0,0 +1,102 @@ +package gohive + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParamsInterpolator_Interpolate(t *testing.T) { + shanghaiLoc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + type fields struct { + Local *time.Location + } + type args struct { + query string + args []driver.Value + } + tests := []struct { + name string + fields fields + args args + want string + wantErr assert.ErrorAssertionFunc + }{ + { + name: "number of ? [1] must be equal to len(args): [2]", + fields: fields{ + Local: time.Local, + }, + args: args{ + query: "SELECT * FROM table_name WHERE id = ?;", + args: []driver.Value{int64(1), string("123")}, + }, + want: "", + wantErr: assert.Error, + }, + { + name: "int", + fields: fields{ + Local: time.Local, + }, + args: args{ + query: "SELECT * FROM table_name WHERE id = ?;", + args: []driver.Value{int64(1)}, + }, + want: "SELECT * FROM table_name WHERE id = 1;", + wantErr: assert.NoError, + }, + { + name: "string bytes time zone", + fields: fields{ + Local: shanghaiLoc, + }, + args: args{ + query: "INSERT INTO table_name (field1, field2, field3) VALUES (?, ?, ?,?);", + args: []driver.Value{int64(1), string("\"hello\""), []byte("123abc&()"), time.Date(2024, 5, 5, 0, 0, 0, 0, shanghaiLoc)}, + }, + want: "INSERT INTO table_name (field1, field2, field3) VALUES (1, '\\\"hello\\\"', X'313233616263262829','2024-05-05 00:00:00');", + wantErr: assert.NoError, + }, + { + name: "\\", + fields: fields{ + Local: time.Local, + }, + args: args{ + query: "UPDATE table_name SET field1 = ?, field2 = ? WHERE id = ?;", + args: []driver.Value{int64(1), string("\"hello\""), []byte("123")}, + }, + want: "UPDATE table_name SET field1 = 1, field2 = '\\\"hello\\\"' WHERE id = X'313233';", + wantErr: assert.NoError, + }, + { + name: "\\\\\\", + fields: fields{ + Local: time.Local, + }, + args: args{ + query: "DELETE FROM table_name WHERE id = ?;", + args: []driver.Value{string(`abc \\\&&&`)}, + }, + want: "DELETE FROM table_name WHERE id = 'abc \\\\\\\\\\\\&&&';", + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &ParamsInterpolator{ + Local: tt.fields.Local, + } + got, err := p.Interpolate(tt.args.query, tt.args.args) + if !tt.wantErr(t, err, fmt.Sprintf("Interpolate(%v, %v)", tt.args.query, tt.args.args)) { + return + } + assert.Equalf(t, tt.want, got, "Interpolate(%v, %v)", tt.args.query, tt.args.args) + }) + } +} diff --git a/rows.go b/rows.go index 11a7151..170da5e 100644 --- a/rows.go +++ b/rows.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "reflect" + "strings" "time" hiveserver2 "sqlflow.org/gohive/hiveserver2/gen-go/tcliservice" @@ -86,6 +87,10 @@ func (r *rowSet) Columns() []string { } ret := make([]string, len(r.columns)) for i, col := range r.columns { + if r.options.ColumnsWithoutTableName { + ret[i] = columnRemoveTable(col.ColumnName) + continue + } ret[i] = col.ColumnName } r.columnStrs = ret @@ -93,6 +98,14 @@ func (r *rowSet) Columns() []string { return r.columnStrs } +func columnRemoveTable(n string) string { + index := strings.Index(n, ".") + if index == -1 { + return n + } + return n[index+1:] +} + func (r *rowSet) Close() (err error) { return nil } @@ -211,7 +224,12 @@ func (r *rowSet) batchFetch() error { r.resultSet = make([][]interface{}, colLen) for i := 0; i < colLen; i++ { - v, length := convertColumn(rs[i]) + typeDesc := r.columns[i].TypeDesc + v, length, err := convertColumn(rs[i], typeDesc) + if err != nil { + return fmt.Errorf("convertColumn failed: %w, col_idx: %d, value: %+v, type: %+v", + err, i, rs[i], typeDesc) + } c := make([]interface{}, length) for j := 0; j < length; j++ { c[j] = reflect.ValueOf(v).Index(j).Interface() @@ -221,25 +239,73 @@ func (r *rowSet) batchFetch() error { return nil } -func convertColumn(col *hiveserver2.TColumn) (colValues interface{}, length int) { +var ( + dateTypeSet = map[hiveserver2.TTypeId]bool{ + hiveserver2.TTypeId_DATE_TYPE: true, + hiveserver2.TTypeId_TIMESTAMP_TYPE: true, + //hiveserver2.TTypeId_TIMESTAMPLOCALTZ_TYPE: true, // TODO: suport TIMESTAMPLOCALTZ + } +) + +func convertColumn(col *hiveserver2.TColumn, typeDesc *hiveserver2.TTypeDesc) (colValues interface{}, length int, err error) { + types := typeDesc.GetTypes() + var primitiveTypeID hiveserver2.TTypeId + var typ *hiveserver2.TTypeEntry + if len(types) > 0 { + typ = types[0] + if typ != nil && typ.GetPrimitiveEntry() != nil { + primitiveTypeID = typ.GetPrimitiveEntry().GetType() + } + } switch { case col.IsSetStringVal(): - return col.GetStringVal().GetValues(), len(col.GetStringVal().GetValues()) + strValues, length := col.GetStringVal().GetValues(), len(col.GetStringVal().GetValues()) + if dateTypeSet[primitiveTypeID] { + colValues, err = convertToDate(strValues, primitiveTypeID) + if err != nil { + return nil, 0, err + } + return colValues, length, nil + } + return strValues, length, nil case col.IsSetBoolVal(): - return col.GetBoolVal().GetValues(), len(col.GetBoolVal().GetValues()) + return col.GetBoolVal().GetValues(), len(col.GetBoolVal().GetValues()), nil case col.IsSetByteVal(): - return col.GetByteVal().GetValues(), len(col.GetByteVal().GetValues()) + return col.GetByteVal().GetValues(), len(col.GetByteVal().GetValues()), nil case col.IsSetI16Val(): - return col.GetI16Val().GetValues(), len(col.GetI16Val().GetValues()) + return col.GetI16Val().GetValues(), len(col.GetI16Val().GetValues()), nil case col.IsSetI32Val(): - return col.GetI32Val().GetValues(), len(col.GetI32Val().GetValues()) + return col.GetI32Val().GetValues(), len(col.GetI32Val().GetValues()), nil case col.IsSetI64Val(): - return col.GetI64Val().GetValues(), len(col.GetI64Val().GetValues()) + return col.GetI64Val().GetValues(), len(col.GetI64Val().GetValues()), nil case col.IsSetDoubleVal(): - return col.GetDoubleVal().GetValues(), len(col.GetDoubleVal().GetValues()) + return col.GetDoubleVal().GetValues(), len(col.GetDoubleVal().GetValues()), nil default: - return nil, 0 + return nil, 0, nil + } +} + +func convertToDate(values []string, primitiveTypeID hiveserver2.TTypeId) (res []time.Time, err error) { + for _, v := range values { + switch primitiveTypeID { + case hiveserver2.TTypeId_DATE_TYPE: + t, err := time.ParseInLocation(DateLayout, v, time.Local) + if err != nil { + return nil, err + } + res = append(res, t) + case hiveserver2.TTypeId_TIMESTAMP_TYPE: + t, err := time.ParseInLocation(TimeStampLayout, v, time.Local) + if err != nil { + return nil, err + } + res = append(res, t) + //case hiveserver2.TTypeId_TIMESTAMPLOCALTZ_TYPE: // TODO: support TIMESTAMPLOCALTZ + default: + return nil, fmt.Errorf("convertToDate failed, unsupported type %s", primitiveTypeID) + } } + return res, nil } func (s hiveStatus) isStopped() bool { diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 0000000..befc546 --- /dev/null +++ b/rows_test.go @@ -0,0 +1,51 @@ +package gohive + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_columnRemoveTable(t *testing.T) { + type args struct { + n string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "", + args: args{ + n: "t.name", + }, + want: "name", + }, + { + name: "", + args: args{ + n: "name", + }, + want: "name", + }, + { + name: "", + args: args{ + n: "", + }, + want: "", + }, + { + name: "", + args: args{ + n: ".", + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, columnRemoveTable(tt.args.n), "columnRemoveTable(%v)", tt.args.n) + }) + } +}