Skip to content

Commit 206a86d

Browse files
authored
sql/plan: cross join implementation (src-d#58)
1 parent 9296559 commit 206a86d

File tree

2 files changed

+211
-0
lines changed

2 files changed

+211
-0
lines changed

Diff for: sql/plan/cross_join.go

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package plan
2+
3+
import (
4+
"io"
5+
6+
"github.com/gitql/gitql/sql"
7+
)
8+
9+
type CrossJoin struct {
10+
BinaryNode
11+
}
12+
13+
func NewCrossJoin(left sql.Node, right sql.Node) *CrossJoin {
14+
return &CrossJoin{
15+
BinaryNode: BinaryNode{
16+
Left: left,
17+
Right: right,
18+
},
19+
}
20+
}
21+
22+
func (p *CrossJoin) Schema() sql.Schema {
23+
return append(p.Left.Schema(), p.Right.Schema()...)
24+
}
25+
26+
func (p *CrossJoin) Resolved() bool {
27+
return p.Left.Resolved() && p.Right.Resolved()
28+
}
29+
30+
func (p *CrossJoin) RowIter() (sql.RowIter, error) {
31+
li, err := p.Left.RowIter()
32+
if err != nil {
33+
return nil, err
34+
}
35+
36+
ri, err := p.Right.RowIter()
37+
if err != nil {
38+
return nil, err
39+
}
40+
41+
return &crossJoinIterator{
42+
li: li,
43+
ri: ri,
44+
}, nil
45+
}
46+
47+
func (p *CrossJoin) TransformUp(f func(sql.Node) sql.Node) sql.Node {
48+
ln := p.BinaryNode.Left.TransformUp(f)
49+
rn := p.BinaryNode.Right.TransformUp(f)
50+
51+
n := NewCrossJoin(ln, rn)
52+
53+
return f(n)
54+
}
55+
56+
func (p *CrossJoin) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
57+
ln := p.BinaryNode.Left.TransformExpressionsUp(f)
58+
rn := p.BinaryNode.Right.TransformExpressionsUp(f)
59+
60+
return NewCrossJoin(ln, rn)
61+
}
62+
63+
type crossJoinIterator struct {
64+
li sql.RowIter
65+
ri sql.RowIter
66+
67+
// TODO use a method to reset right iterator in order to not duplicate rows into memory
68+
rightRows []sql.Row
69+
index int
70+
leftRow sql.Row
71+
}
72+
73+
func (i *crossJoinIterator) Next() (sql.Row, error) {
74+
for {
75+
if i.leftRow == nil {
76+
lr, err := i.li.Next()
77+
if err != nil {
78+
return nil, err
79+
}
80+
81+
i.leftRow = lr
82+
}
83+
84+
if len(i.rightRows) == 0 {
85+
err := i.fillRows()
86+
if err != nil && err != io.EOF {
87+
return nil, err
88+
}
89+
}
90+
91+
if i.index <= len(i.rightRows)-1 {
92+
fields := append(i.leftRow.Fields(), i.rightRows[i.index].Fields()...)
93+
i.index++
94+
95+
return sql.NewMemoryRow(fields...), nil
96+
}
97+
98+
i.index = 0
99+
i.leftRow = nil
100+
}
101+
}
102+
103+
func (i *crossJoinIterator) fillRows() error {
104+
for {
105+
rr, err := i.ri.Next()
106+
if err != nil {
107+
return err
108+
}
109+
110+
i.rightRows = append(i.rightRows, rr)
111+
}
112+
}

Diff for: sql/plan/cross_join_test.go

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package plan
2+
3+
import (
4+
"io"
5+
"testing"
6+
7+
"github.com/gitql/gitql/mem"
8+
"github.com/gitql/gitql/sql"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestCrossJoin(t *testing.T) {
14+
assert := assert.New(t)
15+
lSchema := sql.Schema{
16+
sql.Field{"lcol1", sql.String},
17+
sql.Field{"lcol2", sql.String},
18+
sql.Field{"lcol3", sql.Integer},
19+
sql.Field{"lcol4", sql.BigInteger},
20+
}
21+
22+
rSchema := sql.Schema{
23+
sql.Field{"rcol1", sql.String},
24+
sql.Field{"rcol2", sql.String},
25+
sql.Field{"rcol3", sql.Integer},
26+
sql.Field{"rcol4", sql.BigInteger},
27+
}
28+
29+
resultSchema := sql.Schema{
30+
sql.Field{"lcol1", sql.String},
31+
sql.Field{"lcol2", sql.String},
32+
sql.Field{"lcol3", sql.Integer},
33+
sql.Field{"lcol4", sql.BigInteger},
34+
sql.Field{"rcol1", sql.String},
35+
sql.Field{"rcol2", sql.String},
36+
sql.Field{"rcol3", sql.Integer},
37+
sql.Field{"rcol4", sql.BigInteger},
38+
}
39+
40+
ltable := mem.NewTable("left", lSchema)
41+
rtable := mem.NewTable("right", rSchema)
42+
insertData(assert, ltable)
43+
insertData(assert, rtable)
44+
45+
j := NewCrossJoin(ltable, rtable)
46+
47+
assert.Equal(j.Schema(), resultSchema)
48+
49+
iter, err := j.RowIter()
50+
assert.Nil(err)
51+
assert.NotNil(iter)
52+
53+
row, err := iter.Next()
54+
assert.Nil(err)
55+
assert.NotNil(row)
56+
57+
assert.Equal(8, len(row.Fields()))
58+
59+
assert.Equal("col1_1", row.Fields()[0])
60+
assert.Equal("col2_1", row.Fields()[1])
61+
assert.Equal(int32(1111), row.Fields()[2])
62+
assert.Equal(int64(2222), row.Fields()[3])
63+
assert.Equal("col1_1", row.Fields()[4])
64+
assert.Equal("col2_1", row.Fields()[5])
65+
assert.Equal(int32(1111), row.Fields()[6])
66+
assert.Equal(int64(2222), row.Fields()[7])
67+
68+
row, err = iter.Next()
69+
assert.Nil(err)
70+
assert.NotNil(row)
71+
72+
assert.Equal("col1_1", row.Fields()[0])
73+
assert.Equal("col2_1", row.Fields()[1])
74+
assert.Equal(int32(1111), row.Fields()[2])
75+
assert.Equal(int64(2222), row.Fields()[3])
76+
assert.Equal("col1_2", row.Fields()[4])
77+
assert.Equal("col2_2", row.Fields()[5])
78+
assert.Equal(int32(3333), row.Fields()[6])
79+
assert.Equal(int64(4444), row.Fields()[7])
80+
81+
for i := 0; i < 2; i++ {
82+
row, err = iter.Next()
83+
assert.Nil(err)
84+
assert.NotNil(row)
85+
}
86+
87+
// total: 4 rows
88+
row, err = iter.Next()
89+
assert.NotNil(err)
90+
assert.Equal(err, io.EOF)
91+
assert.Nil(row)
92+
}
93+
94+
func insertData(assert *assert.Assertions, table *mem.Table) {
95+
err := table.Insert("col1_1", "col2_1", int32(1111), int64(2222))
96+
assert.Nil(err)
97+
err = table.Insert("col1_2", "col2_2", int32(3333), int64(4444))
98+
assert.Nil(err)
99+
}

0 commit comments

Comments
 (0)