diff --git a/sql/plan/cross_join.go b/sql/plan/cross_join.go new file mode 100644 index 000000000..5704792a7 --- /dev/null +++ b/sql/plan/cross_join.go @@ -0,0 +1,112 @@ +package plan + +import ( + "io" + + "github.com/gitql/gitql/sql" +) + +type CrossJoin struct { + BinaryNode +} + +func NewCrossJoin(left sql.Node, right sql.Node) *CrossJoin { + return &CrossJoin{ + BinaryNode: BinaryNode{ + Left: left, + Right: right, + }, + } +} + +func (p *CrossJoin) Schema() sql.Schema { + return append(p.Left.Schema(), p.Right.Schema()...) +} + +func (p *CrossJoin) Resolved() bool { + return p.Left.Resolved() && p.Right.Resolved() +} + +func (p *CrossJoin) RowIter() (sql.RowIter, error) { + li, err := p.Left.RowIter() + if err != nil { + return nil, err + } + + ri, err := p.Right.RowIter() + if err != nil { + return nil, err + } + + return &crossJoinIterator{ + li: li, + ri: ri, + }, nil +} + +func (p *CrossJoin) TransformUp(f func(sql.Node) sql.Node) sql.Node { + ln := p.BinaryNode.Left.TransformUp(f) + rn := p.BinaryNode.Right.TransformUp(f) + + n := NewCrossJoin(ln, rn) + + return f(n) +} + +func (p *CrossJoin) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node { + ln := p.BinaryNode.Left.TransformExpressionsUp(f) + rn := p.BinaryNode.Right.TransformExpressionsUp(f) + + return NewCrossJoin(ln, rn) +} + +type crossJoinIterator struct { + li sql.RowIter + ri sql.RowIter + + // TODO use a method to reset right iterator in order to not duplicate rows into memory + rightRows []sql.Row + index int + leftRow sql.Row +} + +func (i *crossJoinIterator) Next() (sql.Row, error) { + for { + if i.leftRow == nil { + lr, err := i.li.Next() + if err != nil { + return nil, err + } + + i.leftRow = lr + } + + if len(i.rightRows) == 0 { + err := i.fillRows() + if err != nil && err != io.EOF { + return nil, err + } + } + + if i.index <= len(i.rightRows)-1 { + fields := append(i.leftRow.Fields(), i.rightRows[i.index].Fields()...) + i.index++ + + return sql.NewMemoryRow(fields...), nil + } + + i.index = 0 + i.leftRow = nil + } +} + +func (i *crossJoinIterator) fillRows() error { + for { + rr, err := i.ri.Next() + if err != nil { + return err + } + + i.rightRows = append(i.rightRows, rr) + } +} diff --git a/sql/plan/cross_join_test.go b/sql/plan/cross_join_test.go new file mode 100644 index 000000000..34d6873a4 --- /dev/null +++ b/sql/plan/cross_join_test.go @@ -0,0 +1,99 @@ +package plan + +import ( + "io" + "testing" + + "github.com/gitql/gitql/mem" + "github.com/gitql/gitql/sql" + + "github.com/stretchr/testify/assert" +) + +func TestCrossJoin(t *testing.T) { + assert := assert.New(t) + lSchema := sql.Schema{ + sql.Field{"lcol1", sql.String}, + sql.Field{"lcol2", sql.String}, + sql.Field{"lcol3", sql.Integer}, + sql.Field{"lcol4", sql.BigInteger}, + } + + rSchema := sql.Schema{ + sql.Field{"rcol1", sql.String}, + sql.Field{"rcol2", sql.String}, + sql.Field{"rcol3", sql.Integer}, + sql.Field{"rcol4", sql.BigInteger}, + } + + resultSchema := sql.Schema{ + sql.Field{"lcol1", sql.String}, + sql.Field{"lcol2", sql.String}, + sql.Field{"lcol3", sql.Integer}, + sql.Field{"lcol4", sql.BigInteger}, + sql.Field{"rcol1", sql.String}, + sql.Field{"rcol2", sql.String}, + sql.Field{"rcol3", sql.Integer}, + sql.Field{"rcol4", sql.BigInteger}, + } + + ltable := mem.NewTable("left", lSchema) + rtable := mem.NewTable("right", rSchema) + insertData(assert, ltable) + insertData(assert, rtable) + + j := NewCrossJoin(ltable, rtable) + + assert.Equal(j.Schema(), resultSchema) + + iter, err := j.RowIter() + assert.Nil(err) + assert.NotNil(iter) + + row, err := iter.Next() + assert.Nil(err) + assert.NotNil(row) + + assert.Equal(8, len(row.Fields())) + + assert.Equal("col1_1", row.Fields()[0]) + assert.Equal("col2_1", row.Fields()[1]) + assert.Equal(int32(1111), row.Fields()[2]) + assert.Equal(int64(2222), row.Fields()[3]) + assert.Equal("col1_1", row.Fields()[4]) + assert.Equal("col2_1", row.Fields()[5]) + assert.Equal(int32(1111), row.Fields()[6]) + assert.Equal(int64(2222), row.Fields()[7]) + + row, err = iter.Next() + assert.Nil(err) + assert.NotNil(row) + + assert.Equal("col1_1", row.Fields()[0]) + assert.Equal("col2_1", row.Fields()[1]) + assert.Equal(int32(1111), row.Fields()[2]) + assert.Equal(int64(2222), row.Fields()[3]) + assert.Equal("col1_2", row.Fields()[4]) + assert.Equal("col2_2", row.Fields()[5]) + assert.Equal(int32(3333), row.Fields()[6]) + assert.Equal(int64(4444), row.Fields()[7]) + + for i := 0; i < 2; i++ { + row, err = iter.Next() + assert.Nil(err) + assert.NotNil(row) + } + + // total: 4 rows + row, err = iter.Next() + assert.NotNil(err) + assert.Equal(err, io.EOF) + assert.Nil(row) +} + +func insertData(assert *assert.Assertions, table *mem.Table) { + err := table.Insert("col1_1", "col2_1", int32(1111), int64(2222)) + assert.Nil(err) + err = table.Insert("col1_2", "col2_2", int32(3333), int64(4444)) + assert.Nil(err) +}