Skip to content

Commit f3acb58

Browse files
authored
sql: support parsing and executing CROSS JOIN. (#62)
* added support to parse CROSS JOIN. * improved CROSS JOIN code. * added tests for CROSS JOIN with one empty side. * added tests for CROSS JOIN parsing and analysis.
1 parent 003dc36 commit f3acb58

File tree

6 files changed

+125
-34
lines changed

6 files changed

+125
-34
lines changed

sql/analyzer/analyzer_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ func TestAnalyzer_Analyze(t *testing.T) {
1717
assert := require.New(t)
1818

1919
table := mem.NewTable("mytable", sql.Schema{{"i", sql.Integer}})
20+
table2 := mem.NewTable("mytable2", sql.Schema{{"i2", sql.Integer}})
2021
db := mem.NewDatabase("mydb")
2122
db.AddTable("mytable", table)
23+
db.AddTable("mytable2", table2)
2224

2325
catalog := &sql.Catalog{Databases: []sql.Database{db}}
2426
a := analyzer.New(catalog)
@@ -95,6 +97,27 @@ func TestAnalyzer_Analyze(t *testing.T) {
9597
)
9698
assert.Nil(err)
9799
assert.Equal(expected, analyzed)
100+
101+
notAnalyzed = plan.NewProject(
102+
[]sql.Expression{
103+
expression.NewUnresolvedColumn("i"),
104+
expression.NewUnresolvedColumn("i2"),
105+
},
106+
plan.NewCrossJoin(
107+
plan.NewUnresolvedTable("mytable"),
108+
plan.NewUnresolvedTable("mytable2"),
109+
),
110+
)
111+
analyzed, err = a.Analyze(notAnalyzed)
112+
expected = plan.NewProject(
113+
[]sql.Expression{
114+
expression.NewGetField(0, sql.Integer, "i"),
115+
expression.NewGetField(1, sql.Integer, "i2"),
116+
},
117+
plan.NewCrossJoin(table, table2),
118+
)
119+
assert.Nil(err)
120+
assert.Equal(expected, analyzed)
98121
}
99122

100123
func TestAnalyzer_Analyze_MaxIterations(t *testing.T) {

sql/analyzer/rules.go

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ func resolveColumns(a *Analyzer, n sql.Node) sql.Node {
3939

4040
child := n.Children()[0]
4141

42+
//TODO: Fail when there is no unambiguous resolution.
4243
colMap := map[string]*expression.GetField{}
4344
for idx, child := range child.Schema() {
4445
colMap[child.Name] = expression.NewGetField(idx, child.Type, child.Name)

sql/parse/parse.go

+26-3
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,34 @@ func convertSelect(s *sqlparser.Select) (sql.Node, error) {
9393
}
9494

9595
func tableExprsToTable(te sqlparser.TableExprs) (sql.Node, error) {
96-
if len(te) != 1 {
97-
return nil, errUnsupportedFeature("more than one table")
96+
if len(te) == 0 {
97+
return nil, errUnsupportedFeature("zero tables in FROM")
9898
}
9999

100-
switch t := (te[0]).(type) {
100+
var nodes []sql.Node
101+
for _, t := range te {
102+
n, err := tableExprToTable(t)
103+
if err != nil {
104+
return nil, err
105+
}
106+
107+
nodes = append(nodes, n)
108+
}
109+
110+
if len(nodes) == 1 {
111+
return nodes[0], nil
112+
}
113+
114+
if len(nodes) == 2 {
115+
return plan.NewCrossJoin(nodes[0], nodes[1]), nil
116+
}
117+
118+
//TODO: Support N tables in JOIN.
119+
return nil, errUnsupportedFeature("more than 2 tables in JOIN")
120+
}
121+
122+
func tableExprToTable(te sqlparser.TableExpr) (sql.Node, error) {
123+
switch t := (te).(type) {
101124
default:
102125
return nil, errUnsupported(te)
103126
case *sqlparser.AliasedTableExpr:

sql/parse/parse_test.go

+10
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ var fixtures = map[string]sql.Node{
117117
),
118118
),
119119
),
120+
`SELECT foo, bar FROM t1, t2;`: plan.NewProject(
121+
[]sql.Expression{
122+
expression.NewUnresolvedColumn("foo"),
123+
expression.NewUnresolvedColumn("bar"),
124+
},
125+
plan.NewCrossJoin(
126+
plan.NewUnresolvedTable("t1"),
127+
plan.NewUnresolvedTable("t2"),
128+
),
129+
),
120130
}
121131

122132
func TestParse(t *testing.T) {

sql/plan/cross_join.go

+18-17
Original file line numberDiff line numberDiff line change
@@ -71,33 +71,34 @@ type crossJoinIterator struct {
7171
}
7272

7373
func (i *crossJoinIterator) Next() (sql.Row, error) {
74-
for {
75-
if i.leftRow == nil {
76-
lr, err := i.li.Next()
77-
if err != nil {
78-
return nil, err
79-
}
80-
81-
i.leftRow = lr
74+
if len(i.rightRows) == 0 {
75+
if err := i.fillRows(); err != io.EOF {
76+
return nil, err
8277
}
8378

8479
if len(i.rightRows) == 0 {
85-
err := i.fillRows()
86-
if err != nil && err != io.EOF {
87-
return nil, err
88-
}
80+
return nil, io.EOF
8981
}
82+
}
9083

91-
if i.index <= len(i.rightRows)-1 {
92-
fields := append(i.leftRow.Fields(), i.rightRows[i.index].Fields()...)
93-
i.index++
94-
95-
return sql.NewMemoryRow(fields...), nil
84+
if i.leftRow == nil {
85+
lr, err := i.li.Next()
86+
if err != nil {
87+
return nil, err
9688
}
9789

90+
i.index = 0
91+
i.leftRow = lr
92+
}
93+
94+
fields := append(i.leftRow.Fields(), i.rightRows[i.index].Fields()...)
95+
i.index++
96+
if i.index >= len(i.rightRows) {
9897
i.index = 0
9998
i.leftRow = nil
10099
}
100+
101+
return sql.NewMemoryRow(fields...), nil
101102
}
102103

103104
func (i *crossJoinIterator) fillRows() error {

sql/plan/cross_join_test.go

+47-14
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@ import (
1010
"github.com/stretchr/testify/assert"
1111
)
1212

13+
var lSchema = sql.Schema{
14+
sql.Field{"lcol1", sql.String},
15+
sql.Field{"lcol2", sql.String},
16+
sql.Field{"lcol3", sql.Integer},
17+
sql.Field{"lcol4", sql.BigInteger},
18+
}
19+
20+
var rSchema = sql.Schema{
21+
sql.Field{"rcol1", sql.String},
22+
sql.Field{"rcol2", sql.String},
23+
sql.Field{"rcol3", sql.Integer},
24+
sql.Field{"rcol4", sql.BigInteger},
25+
}
26+
1327
func TestCrossJoin(t *testing.T) {
1428
assert := assert.New(t)
15-
lSchema := sql.Schema{
16-
sql.Field{"lcol1", sql.String},
17-
sql.Field{"lcol2", sql.String},
18-
sql.Field{"lcol3", sql.Integer},
19-
sql.Field{"lcol4", sql.BigInteger},
20-
}
21-
22-
rSchema := sql.Schema{
23-
sql.Field{"rcol1", sql.String},
24-
sql.Field{"rcol2", sql.String},
25-
sql.Field{"rcol3", sql.Integer},
26-
sql.Field{"rcol4", sql.BigInteger},
27-
}
2829

2930
resultSchema := sql.Schema{
3031
sql.Field{"lcol1", sql.String},
@@ -44,7 +45,7 @@ func TestCrossJoin(t *testing.T) {
4445

4546
j := NewCrossJoin(ltable, rtable)
4647

47-
assert.Equal(j.Schema(), resultSchema)
48+
assert.Equal(resultSchema, j.Schema())
4849

4950
iter, err := j.RowIter()
5051
assert.Nil(err)
@@ -91,6 +92,38 @@ func TestCrossJoin(t *testing.T) {
9192
assert.Nil(row)
9293
}
9394

95+
func TestCrossJoin_Empty(t *testing.T) {
96+
assert := assert.New(t)
97+
98+
ltable := mem.NewTable("left", lSchema)
99+
rtable := mem.NewTable("right", rSchema)
100+
insertData(assert, ltable)
101+
102+
j := NewCrossJoin(ltable, rtable)
103+
104+
iter, err := j.RowIter()
105+
assert.Nil(err)
106+
assert.NotNil(iter)
107+
108+
row, err := iter.Next()
109+
assert.Equal(io.EOF, err)
110+
assert.Nil(row)
111+
112+
ltable = mem.NewTable("left", lSchema)
113+
rtable = mem.NewTable("right", rSchema)
114+
insertData(assert, rtable)
115+
116+
j = NewCrossJoin(ltable, rtable)
117+
118+
iter, err = j.RowIter()
119+
assert.Nil(err)
120+
assert.NotNil(iter)
121+
122+
row, err = iter.Next()
123+
assert.Equal(io.EOF, err)
124+
assert.Nil(row)
125+
}
126+
94127
func insertData(assert *assert.Assertions, table *mem.Table) {
95128
err := table.Insert("col1_1", "col2_1", int32(1111), int64(2222))
96129
assert.Nil(err)

0 commit comments

Comments
 (0)