Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 151d56b

Browse files
authored
*: implement subquery expressions (#835)
*: implement subquery expressions
2 parents 5a8c745 + 258a735 commit 151d56b

26 files changed

+744
-182
lines changed

Diff for: SUPPORTED.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@
8080
- div
8181
- %
8282

83-
## Subqueries
84-
- supported only as tables, not as expressions.
85-
8683
## Functions
8784
- ARRAY_LENGTH
8885
- CEIL
@@ -133,3 +130,6 @@
133130
- WEEKDAY
134131
- YEAR
135132
- YEARWEEK
133+
134+
## Subqueries
135+
Supported both as a table and as expressions but they can't access the parent query scope.

Diff for: engine.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package sqle // import "github.com/src-d/go-mysql-server"
1+
package sqle
22

33
import (
44
"time"

Diff for: engine_test.go

+26-4
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,28 @@ var queries = []struct {
15481548
{int64(5), "there is some text in here"},
15491549
},
15501550
},
1551+
{
1552+
`SELECT i FROM mytable WHERE i = (SELECT 1)`,
1553+
[]sql.Row{{int64(1)}},
1554+
},
1555+
{
1556+
`SELECT i FROM mytable WHERE i IN (SELECT i FROM mytable)`,
1557+
[]sql.Row{
1558+
{int64(1)},
1559+
{int64(2)},
1560+
{int64(3)},
1561+
},
1562+
},
1563+
{
1564+
`SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM mytable ORDER BY i ASC LIMIT 2)`,
1565+
[]sql.Row{
1566+
{int64(3)},
1567+
},
1568+
},
1569+
{
1570+
`SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`,
1571+
[]sql.Row{{int64(1)}},
1572+
},
15511573
}
15521574

15531575
func TestQueries(t *testing.T) {
@@ -1901,7 +1923,7 @@ func TestInsertInto(t *testing.T) {
19011923
[]sql.Row{{int64(1)}},
19021924
"SELECT * FROM typestable WHERE id = 999;",
19031925
[]sql.Row{{
1904-
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
1926+
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
19051927
int64(0), int64(0), int64(0), int64(0),
19061928
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
19071929
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
@@ -1919,7 +1941,7 @@ func TestInsertInto(t *testing.T) {
19191941
[]sql.Row{{int64(1)}},
19201942
"SELECT * FROM typestable WHERE id = 999;",
19211943
[]sql.Row{{
1922-
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
1944+
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
19231945
int64(0), int64(0), int64(0), int64(0),
19241946
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
19251947
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
@@ -2101,7 +2123,7 @@ func TestReplaceInto(t *testing.T) {
21012123
[]sql.Row{{int64(1)}},
21022124
"SELECT * FROM typestable WHERE id = 999;",
21032125
[]sql.Row{{
2104-
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
2126+
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
21052127
int64(0), int64(0), int64(0), int64(0),
21062128
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
21072129
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
@@ -2119,7 +2141,7 @@ func TestReplaceInto(t *testing.T) {
21192141
[]sql.Row{{int64(1)}},
21202142
"SELECT * FROM typestable WHERE id = 999;",
21212143
[]sql.Row{{
2122-
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
2144+
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
21232145
int64(0), int64(0), int64(0), int64(0),
21242146
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
21252147
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),

Diff for: log.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package sqle // import "github.com/src-d/go-mysql-server"
1+
package sqle
22

33
import (
44
"github.com/golang/glog"

Diff for: server/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package server // import "github.com/src-d/go-mysql-server/server"
1+
package server
22

33
import (
44
"time"

Diff for: sql/analyzer/analyzer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package analyzer // import "github.com/src-d/go-mysql-server/sql/analyzer"
1+
package analyzer
22

33
import (
44
"os"

Diff for: sql/analyzer/assign_indexes.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,20 @@ func containsColumns(e sql.Expression) bool {
759759
return result
760760
}
761761

762+
func containsSubquery(e sql.Expression) bool {
763+
var result bool
764+
expression.Inspect(e, func(e sql.Expression) bool {
765+
if _, ok := e.(*expression.Subquery); ok {
766+
result = true
767+
return false
768+
}
769+
return true
770+
})
771+
return result
772+
}
773+
762774
func isEvaluable(e sql.Expression) bool {
763-
return !containsColumns(e)
775+
return !containsColumns(e) && !containsSubquery(e)
764776
}
765777

766778
func canMergeIndexes(a, b sql.IndexLookup) bool {

Diff for: sql/analyzer/resolve_subqueries.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package analyzer
22

33
import (
44
"github.com/src-d/go-mysql-server/sql"
5+
"github.com/src-d/go-mysql-server/sql/expression"
56
"github.com/src-d/go-mysql-server/sql/plan"
67
)
78

@@ -10,7 +11,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
1011
defer span.Finish()
1112

1213
a.Log("resolving subqueries")
13-
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
14+
n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
1415
switch n := n.(type) {
1516
case *plan.SubqueryAlias:
1617
a.Log("found subquery %q with child of type %T", n.Name(), n.Child)
@@ -24,4 +25,25 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
2425
return n, nil
2526
}
2627
})
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) {
33+
s, ok := e.(*expression.Subquery)
34+
if !ok || s.Resolved() {
35+
return e, nil
36+
}
37+
38+
q, err := a.Analyze(ctx, s.Query)
39+
if err != nil {
40+
return nil, err
41+
}
42+
43+
if qp, ok := q.(*plan.QueryProcess); ok {
44+
q = qp.Child
45+
}
46+
47+
return s.WithQuery(q), nil
48+
})
2749
}

Diff for: sql/analyzer/validation_rules.go

+27
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const (
2020
validateCaseResultTypesRule = "validate_case_result_types"
2121
validateIntervalUsageRule = "validate_interval_usage"
2222
validateExplodeUsageRule = "validate_explode_usage"
23+
validateSubqueryColumnsRule = "validate_subquery_columns"
2324
)
2425

2526
var (
@@ -57,6 +58,12 @@ var (
5758
ErrExplodeInvalidUse = errors.NewKind(
5859
"using EXPLODE is not supported outside a Project node",
5960
)
61+
62+
// ErrSubqueryColumns is returned when an expression subquery returns
63+
// more than a single column.
64+
ErrSubqueryColumns = errors.NewKind(
65+
"subquery expressions can only return a single column",
66+
)
6067
)
6168

6269
// DefaultValidationRules to apply while analyzing nodes.
@@ -70,6 +77,7 @@ var DefaultValidationRules = []Rule{
7077
{validateCaseResultTypesRule, validateCaseResultTypes},
7178
{validateIntervalUsageRule, validateIntervalUsage},
7279
{validateExplodeUsageRule, validateExplodeUsage},
80+
{validateSubqueryColumnsRule, validateSubqueryColumns},
7381
}
7482

7583
func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
@@ -322,6 +330,25 @@ func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node,
322330
return n, nil
323331
}
324332

333+
func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
334+
valid := true
335+
plan.InspectExpressions(n, func(e sql.Expression) bool {
336+
s, ok := e.(*expression.Subquery)
337+
if ok && len(s.Query.Schema()) != 1 {
338+
valid = false
339+
return false
340+
}
341+
342+
return true
343+
})
344+
345+
if !valid {
346+
return nil, ErrSubqueryColumns.New()
347+
}
348+
349+
return n, nil
350+
}
351+
325352
func stringContains(strs []string, target string) bool {
326353
for _, s := range strs {
327354
if s == target {

Diff for: sql/analyzer/validation_rules_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,37 @@ func TestValidateExplodeUsage(t *testing.T) {
674674
}
675675
}
676676

677+
func TestValidateSubqueryColumns(t *testing.T) {
678+
require := require.New(t)
679+
ctx := sql.NewEmptyContext()
680+
681+
node := plan.NewProject([]sql.Expression{
682+
expression.NewSubquery(plan.NewProject(
683+
[]sql.Expression{
684+
lit(1),
685+
lit(2),
686+
},
687+
dummyNode{true},
688+
)),
689+
}, dummyNode{true})
690+
691+
_, err := validateSubqueryColumns(ctx, nil, node)
692+
require.Error(err)
693+
require.True(ErrSubqueryColumns.Is(err))
694+
695+
node = plan.NewProject([]sql.Expression{
696+
expression.NewSubquery(plan.NewProject(
697+
[]sql.Expression{
698+
lit(1),
699+
},
700+
dummyNode{true},
701+
)),
702+
}, dummyNode{true})
703+
704+
_, err = validateSubqueryColumns(ctx, nil, node)
705+
require.NoError(err)
706+
}
707+
677708
type dummyNode struct{ resolved bool }
678709

679710
func (n dummyNode) String() string { return "dummynode" }

Diff for: sql/core.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package sql // import "github.com/src-d/go-mysql-server/sql"
1+
package sql
22

33
import (
44
"fmt"

Diff for: sql/expression/comparison.go

+56-2
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
466466
return nil, err
467467
}
468468

469-
// TODO: support subqueries
470469
switch right := in.Right().(type) {
471470
case Tuple:
472471
for _, el := range right {
@@ -496,6 +495,34 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
496495
}
497496
}
498497

498+
return false, nil
499+
case *Subquery:
500+
if leftElems > 1 {
501+
return nil, ErrInvalidOperandColumns.New(leftElems, 1)
502+
}
503+
504+
typ := right.Type()
505+
values, err := right.EvalMultiple(ctx)
506+
if err != nil {
507+
return nil, err
508+
}
509+
510+
for _, val := range values {
511+
val, err = typ.Convert(val)
512+
if err != nil {
513+
return nil, err
514+
}
515+
516+
cmp, err := typ.Compare(left, val)
517+
if err != nil {
518+
return nil, err
519+
}
520+
521+
if cmp == 0 {
522+
return true, nil
523+
}
524+
}
525+
499526
return false, nil
500527
default:
501528
return nil, ErrUnsupportedInOperand.New(right)
@@ -547,7 +574,6 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
547574
return nil, err
548575
}
549576

550-
// TODO: support subqueries
551577
switch right := in.Right().(type) {
552578
case Tuple:
553579
for _, el := range right {
@@ -577,6 +603,34 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
577603
}
578604
}
579605

606+
return true, nil
607+
case *Subquery:
608+
if leftElems > 1 {
609+
return nil, ErrInvalidOperandColumns.New(leftElems, 1)
610+
}
611+
612+
typ := right.Type()
613+
values, err := right.EvalMultiple(ctx)
614+
if err != nil {
615+
return nil, err
616+
}
617+
618+
for _, val := range values {
619+
val, err = typ.Convert(val)
620+
if err != nil {
621+
return nil, err
622+
}
623+
624+
cmp, err := typ.Compare(left, val)
625+
if err != nil {
626+
return nil, err
627+
}
628+
629+
if cmp == 0 {
630+
return false, nil
631+
}
632+
}
633+
580634
return true, nil
581635
default:
582636
return nil, ErrUnsupportedInOperand.New(right)

0 commit comments

Comments
 (0)