Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Implemented UPDATE #832

Merged
merged 2 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (e *Engine) Query(
case *plan.CreateIndex:
typ = sql.CreateIndexProcess
perm = auth.ReadPerm | auth.WritePerm
case *plan.InsertInto, *plan.DeleteFrom, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
perm = auth.ReadPerm | auth.WritePerm
}

Expand Down
149 changes: 143 additions & 6 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/src-d/go-mysql-server/sql/plan"
"github.com/src-d/go-mysql-server/test"


"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -2245,6 +2246,142 @@ func TestReplaceIntoErrors(t *testing.T) {
}
}

func TestUpdate(t *testing.T) {
var updates = []struct {
updateQuery string
expectedUpdate []sql.Row
selectQuery string
expectedSelect []sql.Row
}{
{
"UPDATE mytable SET s = 'updated';",
[]sql.Row{{int64(3), int64(3)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
},
{
"UPDATE mytable SET s = 'updated' WHERE i > 9999;",
[]sql.Row{{int64(0), int64(0)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
},
{
"UPDATE mytable SET s = 'updated' WHERE i = 1;",
[]sql.Row{{int64(1), int64(1)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}},
},
{
"UPDATE mytable SET s = 'updated' WHERE i <> 9999;",
[]sql.Row{{int64(3), int64(3)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "updated"},{int64(2), "updated"},{int64(3), "updated"}},
},
{
"UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;",
[]sql.Row{{int64(1), int64(1)}},
"SELECT * FROM floattable WHERE i = 2;",
[]sql.Row{{int64(2), float32(3.0), float64(4.5)}},
},
{
"UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;",
[]sql.Row{{int64(1), int64(1)}},
"SELECT f32 FROM floattable WHERE i = 1;",
[]sql.Row{{float32(4.0)}},
},
{
"UPDATE mytable SET s = 'first row' WHERE i = 1;",
[]sql.Row{{int64(1), int64(0)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
},
{
"UPDATE niltable SET b = NULL WHERE f IS NULL;",
[]sql.Row{{int64(2), int64(1)}},
"SELECT * FROM niltable WHERE f IS NULL;",
[]sql.Row{{int64(4), nil, nil}, {nil, nil, nil}},
},
{
"UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;",
[]sql.Row{{int64(2), int64(2)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}},
},
{
"UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;",
[]sql.Row{{int64(2), int64(2)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}},
},
{
"UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;",
[]sql.Row{{int64(1), int64(1)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}},
},
{
"UPDATE mytable SET s = 'updated';",
[]sql.Row{{int64(3), int64(3)}},
"SELECT * FROM mytable;",
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
},
}

for _, update := range updates {
e := newEngine(t)
ctx := newCtx()
testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate)
testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect)
}
}

func TestUpdateErrors(t *testing.T) {
var expectedFailures = []struct {
name string
query string
}{
{
"invalid table",
"UPDATE doesnotexist SET i = 0;",
},
{
"invalid column set",
"UPDATE mytable SET z = 0;",
},
{
"invalid column set value",
"UPDATE mytable SET i = z;",
},
{
"invalid column where",
"UPDATE mytable SET s = 'hi' WHERE z = 1;",
},
{
"invalid column order by",
"UPDATE mytable SET s = 'hi' ORDER BY z;",
},
{
"negative limit",
"UPDATE mytable SET s = 'hi' LIMIT -1;",
},
{
"negative offset",
"UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;",
},
{
"set null on non-nullable",
"UPDATE mytable SET s = NULL;",
},
}

for _, expectedFailure := range expectedFailures {
t.Run(expectedFailure.name, func(t *testing.T) {
_, _, err := newEngine(t).Query(newCtx(), expectedFailure.query)
require.Error(t, err)
})
}
}

const testNumPartitions = 5

func TestAmbiguousColumnResolution(t *testing.T) {
Expand Down Expand Up @@ -2670,12 +2807,12 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine {

insertRows(
t, floatTable,
sql.NewRow(1, float32(1.0), float64(1.0)),
sql.NewRow(2, float32(1.5), float64(1.5)),
sql.NewRow(3, float32(2.0), float64(2.0)),
sql.NewRow(4, float32(2.5), float64(2.5)),
sql.NewRow(-1, float32(-1.0), float64(-1.0)),
sql.NewRow(-2, float32(-1.5), float64(-1.5)),
sql.NewRow(int64(1), float32(1.0), float64(1.0)),
sql.NewRow(int64(2), float32(1.5), float64(1.5)),
sql.NewRow(int64(3), float32(2.0), float64(2.0)),
sql.NewRow(int64(4), float32(2.5), float64(2.5)),
sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)),
sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)),
)

nilTable := memory.NewPartitionedTable("niltable", sql.Schema{
Expand Down
31 changes: 31 additions & 0 deletions memory/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,37 @@ func (t *Table) Delete(ctx *sql.Context, row sql.Row) error {
return nil
}

func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
if err := checkRow(t.schema, oldRow); err != nil {
return err
}
if err := checkRow(t.schema, newRow); err != nil {
return err
}

matches := false
for partitionIndex, partition := range t.partitions {
for partitionRowIndex, partitionRow := range partition {
matches = true
for rIndex, val := range oldRow {
if val != partitionRow[rIndex] {
matches = false
break
}
}
if matches {
t.partitions[partitionIndex][partitionRowIndex] = newRow
break
}
}
if matches {
break
}
}

return nil
}

func checkRow(schema sql.Schema, row sql.Row) error {
if len(row) != len(schema) {
return sql.ErrUnexpectedRowLength.New(len(schema), len(row))
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/pushdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

// don't do pushdown on certain queries
switch n.(type) {
case *plan.InsertInto, *plan.DeleteFrom, *plan.CreateIndex:
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex:
return n, nil
}

Expand Down
6 changes: 6 additions & 0 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ type Replacer interface {
Inserter
}

// Updater allows rows to be updated.
type Updater interface {
// Update the given row. Provides both the old and new rows.
Update(ctx *Context, old Row, new Row) error
}

// Database represents the database.
type Database interface {
Nameable
Expand Down
61 changes: 61 additions & 0 deletions sql/expression/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package expression

import (
"fmt"
"github.com/src-d/go-mysql-server/sql"
"gopkg.in/src-d/go-errors.v1"
)

var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T")

// SetField updates the value of a field from a row.
type SetField struct {
BinaryExpression
}

// NewSetField creates a new SetField expression.
func NewSetField(colName, expr sql.Expression) sql.Expression {
return &SetField{BinaryExpression{Left: colName, Right: expr}}
}

func (s *SetField) String() string {
return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right)
}

// Type implements the Expression interface.
func (s *SetField) Type() sql.Type {
return s.Left.Type()
}

// Eval implements the Expression interface.
// Returns a copy of the given row with an updated value.
func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
getField, ok := s.Left.(*GetField)
if !ok {
return nil, errCannotSetField.New(s.Left)
}
if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) {
return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row))
}
val, err := s.Right.Eval(ctx, row)
if err != nil {
return nil, err
}
if val != nil {
val, err = getField.fieldType.Convert(val)
if err != nil {
return nil, err
}
}
updatedRow := row.Copy()
updatedRow[getField.fieldIndex] = val
return updatedRow, nil
}

// WithChildren implements the Expression interface.
func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 2 {
return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2)
}
return NewSetField(children[0], children[1]), nil
}
61 changes: 61 additions & 0 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
return plan.NewRollback(), nil
case *sqlparser.Delete:
return convertDelete(ctx, n)
case *sqlparser.Update:
return convertUpdate(ctx, n)
}
}

Expand Down Expand Up @@ -429,6 +431,49 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) {
return plan.NewDeleteFrom(node), nil
}

func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) {
node, err := tableExprsToTable(ctx, d.TableExprs)
if err != nil {
return nil, err
}

updateExprs, err := updateExprsToExpressions(ctx, d.Exprs)
if err != nil {
return nil, err
}

if d.Where != nil {
node, err = whereToFilter(ctx, d.Where, node)
if err != nil {
return nil, err
}
}

if len(d.OrderBy) != 0 {
node, err = orderByToSort(ctx, d.OrderBy, node)
if err != nil {
return nil, err
}
}

// Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count.
if d.Limit != nil && d.Limit.Offset != nil {
node, err = offsetToOffset(ctx, d.Limit.Offset, node)
if err != nil {
return nil, err
}
}

if d.Limit != nil {
node, err = limitToLimit(ctx, d.Limit.Rowcount, node)
if err != nil {
return nil, err
}
}

return plan.NewUpdate(node, updateExprs), nil
}

func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) {
var schema sql.Schema
for _, cd := range colDef {
Expand Down Expand Up @@ -1241,6 +1286,22 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.
return expression.NewInterval(expr, e.Unit), nil
}

func updateExprsToExpressions(ctx *sql.Context, e sqlparser.UpdateExprs) ([]sql.Expression, error) {
res := make([]sql.Expression, len(e))
for i, updateExpr := range e {
colName, err := exprToExpression(ctx, updateExpr.Name)
if err != nil {
return nil, err
}
innerExpr, err := exprToExpression(ctx, updateExpr.Expr)
if err != nil {
return nil, err
}
res[i] = expression.NewSetField(colName, innerExpr)
}
return res, nil
}

func removeComments(s string) string {
r := bufio.NewReader(strings.NewReader(s))
var result []rune
Expand Down
Loading