Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit fbdd9bf

Browse files
committed
Implemented UPDATE
Signed-off-by: Daylon Wilkins <[email protected]>
1 parent b09e8c1 commit fbdd9bf

File tree

8 files changed

+493
-8
lines changed

8 files changed

+493
-8
lines changed

engine.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (e *Engine) Query(
120120
case *plan.CreateIndex:
121121
typ = sql.CreateIndexProcess
122122
perm = auth.ReadPerm | auth.WritePerm
123-
case *plan.InsertInto, *plan.DeleteFrom, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
123+
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
124124
perm = auth.ReadPerm | auth.WritePerm
125125
}
126126

engine_test.go

+143-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/src-d/go-mysql-server/sql/plan"
1919
"github.com/src-d/go-mysql-server/test"
2020

21+
2122
"github.com/stretchr/testify/require"
2223
)
2324

@@ -2209,6 +2210,142 @@ func TestReplaceIntoErrors(t *testing.T) {
22092210
}
22102211
}
22112212

2213+
func TestUpdate(t *testing.T) {
2214+
var updates = []struct {
2215+
updateQuery string
2216+
expectedUpdate []sql.Row
2217+
selectQuery string
2218+
expectedSelect []sql.Row
2219+
}{
2220+
{
2221+
"UPDATE mytable SET s = 'updated';",
2222+
[]sql.Row{{int64(3), int64(3)}},
2223+
"SELECT * FROM mytable;",
2224+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
2225+
},
2226+
{
2227+
"UPDATE mytable SET s = 'updated' WHERE i > 9999;",
2228+
[]sql.Row{{int64(0), int64(0)}},
2229+
"SELECT * FROM mytable;",
2230+
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
2231+
},
2232+
{
2233+
"UPDATE mytable SET s = 'updated' WHERE i = 1;",
2234+
[]sql.Row{{int64(1), int64(1)}},
2235+
"SELECT * FROM mytable;",
2236+
[]sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}},
2237+
},
2238+
{
2239+
"UPDATE mytable SET s = 'updated' WHERE i <> 9999;",
2240+
[]sql.Row{{int64(3), int64(3)}},
2241+
"SELECT * FROM mytable;",
2242+
[]sql.Row{{int64(1), "updated"},{int64(2), "updated"},{int64(3), "updated"}},
2243+
},
2244+
{
2245+
"UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;",
2246+
[]sql.Row{{int64(1), int64(1)}},
2247+
"SELECT * FROM floattable WHERE i = 2;",
2248+
[]sql.Row{{int64(2), float32(3.0), float64(4.5)}},
2249+
},
2250+
{
2251+
"UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;",
2252+
[]sql.Row{{int64(1), int64(1)}},
2253+
"SELECT f32 FROM floattable WHERE i = 1;",
2254+
[]sql.Row{{float32(4.0)}},
2255+
},
2256+
{
2257+
"UPDATE mytable SET s = 'first row' WHERE i = 1;",
2258+
[]sql.Row{{int64(1), int64(0)}},
2259+
"SELECT * FROM mytable;",
2260+
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
2261+
},
2262+
{
2263+
"UPDATE niltable SET b = NULL WHERE f IS NULL;",
2264+
[]sql.Row{{int64(2), int64(1)}},
2265+
"SELECT * FROM niltable WHERE f IS NULL;",
2266+
[]sql.Row{{int64(4), nil, nil}, {nil, nil, nil}},
2267+
},
2268+
{
2269+
"UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;",
2270+
[]sql.Row{{int64(2), int64(2)}},
2271+
"SELECT * FROM mytable;",
2272+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}},
2273+
},
2274+
{
2275+
"UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;",
2276+
[]sql.Row{{int64(2), int64(2)}},
2277+
"SELECT * FROM mytable;",
2278+
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}},
2279+
},
2280+
{
2281+
"UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;",
2282+
[]sql.Row{{int64(1), int64(1)}},
2283+
"SELECT * FROM mytable;",
2284+
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}},
2285+
},
2286+
{
2287+
"UPDATE mytable SET s = 'updated';",
2288+
[]sql.Row{{int64(3), int64(3)}},
2289+
"SELECT * FROM mytable;",
2290+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
2291+
},
2292+
}
2293+
2294+
for _, update := range updates {
2295+
e := newEngine(t)
2296+
ctx := newCtx()
2297+
testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate)
2298+
testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect)
2299+
}
2300+
}
2301+
2302+
func TestUpdateErrors(t *testing.T) {
2303+
var expectedFailures = []struct {
2304+
name string
2305+
query string
2306+
}{
2307+
{
2308+
"invalid table",
2309+
"UPDATE doesnotexist SET i = 0;",
2310+
},
2311+
{
2312+
"invalid column set",
2313+
"UPDATE mytable SET z = 0;",
2314+
},
2315+
{
2316+
"invalid column set value",
2317+
"UPDATE mytable SET i = z;",
2318+
},
2319+
{
2320+
"invalid column where",
2321+
"UPDATE mytable SET s = 'hi' WHERE z = 1;",
2322+
},
2323+
{
2324+
"invalid column order by",
2325+
"UPDATE mytable SET s = 'hi' ORDER BY z;",
2326+
},
2327+
{
2328+
"negative limit",
2329+
"UPDATE mytable SET s = 'hi' LIMIT -1;",
2330+
},
2331+
{
2332+
"negative offset",
2333+
"UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;",
2334+
},
2335+
{
2336+
"set null on non-nullable",
2337+
"UPDATE mytable SET s = NULL;",
2338+
},
2339+
}
2340+
2341+
for _, expectedFailure := range expectedFailures {
2342+
t.Run(expectedFailure.name, func(t *testing.T) {
2343+
_, _, err := newEngine(t).Query(newCtx(), expectedFailure.query)
2344+
require.Error(t, err)
2345+
})
2346+
}
2347+
}
2348+
22122349
const testNumPartitions = 5
22132350

22142351
func TestAmbiguousColumnResolution(t *testing.T) {
@@ -2634,12 +2771,12 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine {
26342771

26352772
insertRows(
26362773
t, floatTable,
2637-
sql.NewRow(1, float32(1.0), float64(1.0)),
2638-
sql.NewRow(2, float32(1.5), float64(1.5)),
2639-
sql.NewRow(3, float32(2.0), float64(2.0)),
2640-
sql.NewRow(4, float32(2.5), float64(2.5)),
2641-
sql.NewRow(-1, float32(-1.0), float64(-1.0)),
2642-
sql.NewRow(-2, float32(-1.5), float64(-1.5)),
2774+
sql.NewRow(int64(1), float32(1.0), float64(1.0)),
2775+
sql.NewRow(int64(2), float32(1.5), float64(1.5)),
2776+
sql.NewRow(int64(3), float32(2.0), float64(2.0)),
2777+
sql.NewRow(int64(4), float32(2.5), float64(2.5)),
2778+
sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)),
2779+
sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)),
26432780
)
26442781

26452782
nilTable := memory.NewPartitionedTable("niltable", sql.Schema{

memory/table.go

+31
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,37 @@ func (t *Table) Delete(ctx *sql.Context, row sql.Row) error {
290290
return nil
291291
}
292292

293+
func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
294+
if err := checkRow(t.schema, oldRow); err != nil {
295+
return err
296+
}
297+
if err := checkRow(t.schema, newRow); err != nil {
298+
return err
299+
}
300+
301+
matches := false
302+
for partitionIndex, partition := range t.partitions {
303+
for partitionRowIndex, partitionRow := range partition {
304+
matches = true
305+
for rIndex, val := range oldRow {
306+
if val != partitionRow[rIndex] {
307+
matches = false
308+
break
309+
}
310+
}
311+
if matches {
312+
t.partitions[partitionIndex][partitionRowIndex] = newRow
313+
break
314+
}
315+
}
316+
if matches {
317+
break
318+
}
319+
}
320+
321+
return nil
322+
}
323+
293324
func checkRow(schema sql.Schema, row sql.Row) error {
294325
if len(row) != len(schema) {
295326
return sql.ErrUnexpectedRowLength.New(len(schema), len(row))

sql/analyzer/pushdown.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
2020

2121
// don't do pushdown on certain queries
2222
switch n.(type) {
23-
case *plan.InsertInto, *plan.DeleteFrom, *plan.CreateIndex:
23+
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex:
2424
return n, nil
2525
}
2626

sql/core.go

+6
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,12 @@ type Replacer interface {
217217
Inserter
218218
}
219219

220+
// Updater allows rows to be updated.
221+
type Updater interface {
222+
// Update the given row. Provides both the old and new rows.
223+
Update(ctx *Context, old Row, new Row) error
224+
}
225+
220226
// Database represents the database.
221227
type Database interface {
222228
Nameable

sql/expression/set.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package expression
2+
3+
import (
4+
"fmt"
5+
"github.com/src-d/go-mysql-server/sql"
6+
"gopkg.in/src-d/go-errors.v1"
7+
)
8+
9+
var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T")
10+
11+
// SetField updates the value of a field from a row.
12+
type SetField struct {
13+
BinaryExpression
14+
}
15+
16+
// NewSetField creates a new SetField expression.
17+
func NewSetField(colName, expr sql.Expression) sql.Expression {
18+
return &SetField{BinaryExpression{Left: colName, Right: expr}}
19+
}
20+
21+
func (s *SetField) String() string {
22+
return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right)
23+
}
24+
25+
// Type implements the Expression interface.
26+
func (s *SetField) Type() sql.Type {
27+
return s.Left.Type()
28+
}
29+
30+
// Eval implements the Expression interface.
31+
// Returns a copy of the given row with an updated value.
32+
func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
33+
getField, ok := s.Left.(*GetField)
34+
if !ok {
35+
return nil, errCannotSetField.New(s.Left)
36+
}
37+
if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) {
38+
return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row))
39+
}
40+
val, err := s.Right.Eval(ctx, row)
41+
if err != nil {
42+
return nil, err
43+
}
44+
if val != nil {
45+
val, err = getField.fieldType.Convert(val)
46+
if err != nil {
47+
return nil, err
48+
}
49+
}
50+
updatedRow := row.Copy()
51+
updatedRow[getField.fieldIndex] = val
52+
return updatedRow, nil
53+
}
54+
55+
// WithChildren implements the Expression interface.
56+
func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) {
57+
if len(children) != 2 {
58+
return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2)
59+
}
60+
return NewSetField(children[0], children[1]), nil
61+
}

sql/parse/parse.go

+61
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
150150
return plan.NewRollback(), nil
151151
case *sqlparser.Delete:
152152
return convertDelete(ctx, n)
153+
case *sqlparser.Update:
154+
return convertUpdate(ctx, n)
153155
}
154156
}
155157

@@ -426,6 +428,49 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) {
426428
return plan.NewDeleteFrom(node), nil
427429
}
428430

431+
func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) {
432+
node, err := tableExprsToTable(ctx, d.TableExprs)
433+
if err != nil {
434+
return nil, err
435+
}
436+
437+
updateExprs, err := updateExprsToExpressions(d.Exprs)
438+
if err != nil {
439+
return nil, err
440+
}
441+
442+
if d.Where != nil {
443+
node, err = whereToFilter(d.Where, node)
444+
if err != nil {
445+
return nil, err
446+
}
447+
}
448+
449+
if len(d.OrderBy) != 0 {
450+
node, err = orderByToSort(d.OrderBy, node)
451+
if err != nil {
452+
return nil, err
453+
}
454+
}
455+
456+
// Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count.
457+
if d.Limit != nil && d.Limit.Offset != nil {
458+
node, err = offsetToOffset(ctx, d.Limit.Offset, node)
459+
if err != nil {
460+
return nil, err
461+
}
462+
}
463+
464+
if d.Limit != nil {
465+
node, err = limitToLimit(ctx, d.Limit.Rowcount, node)
466+
if err != nil {
467+
return nil, err
468+
}
469+
}
470+
471+
return plan.NewUpdate(node, updateExprs), nil
472+
}
473+
429474
func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) {
430475
var schema sql.Schema
431476
for _, cd := range colDef {
@@ -1201,6 +1246,22 @@ func intervalExprToExpression(e *sqlparser.IntervalExpr) (sql.Expression, error)
12011246
return expression.NewInterval(expr, e.Unit), nil
12021247
}
12031248

1249+
func updateExprsToExpressions(e sqlparser.UpdateExprs) ([]sql.Expression, error) {
1250+
res := make([]sql.Expression, len(e))
1251+
for i, updateExpr := range e {
1252+
colName, err := exprToExpression(updateExpr.Name)
1253+
if err != nil {
1254+
return nil, err
1255+
}
1256+
innerExpr, err := exprToExpression(updateExpr.Expr)
1257+
if err != nil {
1258+
return nil, err
1259+
}
1260+
res[i] = expression.NewSetField(colName, innerExpr)
1261+
}
1262+
return res, nil
1263+
}
1264+
12041265
func removeComments(s string) string {
12051266
r := bufio.NewReader(strings.NewReader(s))
12061267
var result []rune

0 commit comments

Comments
 (0)