Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 166881a

Browse files
committed
sql: implement count distinct
Signed-off-by: Miguel Molina <[email protected]>
1 parent 7f8224b commit 166881a

File tree

7 files changed

+196
-27
lines changed

7 files changed

+196
-27
lines changed

SUPPORTED.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
## Grouping expressions
2020
- AVG
21-
- COUNT
21+
- COUNT and COUNT(DISTINCT)
2222
- MAX
2323
- MIN
2424
- SUM (always returns DOUBLE)

engine_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,10 @@ var queries = []struct {
12911291
`SELECT LAST(i) FROM (SELECT i FROM mytable ORDER BY i) t`,
12921292
[]sql.Row{{int64(3)}},
12931293
},
1294+
{
1295+
`SELECT COUNT(DISTINCT t.i) FROM tabletest t, mytable t2`,
1296+
[]sql.Row{{int64(3)}},
1297+
},
12941298
}
12951299

12961300
func TestQueries(t *testing.T) {

sql/analyzer/resolve_having.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,11 @@ func aggregationEquals(a, b sql.Expression) bool {
383383
// the same.
384384
_, ok := b.(*aggregation.Count)
385385
return ok
386+
case *aggregation.CountDistinct:
387+
// NOTE(review): only the node type is compared here; the DISTINCT argument
388+
// is ignored, so COUNT(DISTINCT a) equals COUNT(DISTINCT b) — confirm intended.
389+
_, ok := b.(*aggregation.CountDistinct)
390+
return ok
386391
case *aggregation.Sum:
387392
b, ok := b.(*aggregation.Sum)
388393
if !ok {

sql/expression/function/aggregation/count.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package aggregation
33
import (
44
"fmt"
55

6+
"github.com/mitchellh/hashstructure"
67
"github.com/src-d/go-mysql-server/sql"
78
"github.com/src-d/go-mysql-server/sql/expression"
89
)
@@ -87,3 +88,93 @@ func (c *Count) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
8788
count := buffer[0]
8889
return count, nil
8990
}
91+
92+
// CountDistinct node to count how many rows are in the result set.
93+
type CountDistinct struct {
94+
expression.UnaryExpression
95+
}
96+
97+
// NewCountDistinct creates a new CountDistinct node.
98+
func NewCountDistinct(e sql.Expression) *CountDistinct {
99+
return &CountDistinct{expression.UnaryExpression{Child: e}}
100+
}
101+
102+
// NewBuffer creates a new buffer for the aggregation.
103+
func (c *CountDistinct) NewBuffer() sql.Row {
104+
return sql.NewRow(make(map[uint64]struct{}))
105+
}
106+
107+
// Type returns the type of the result.
108+
func (c *CountDistinct) Type() sql.Type {
109+
return sql.Int64
110+
}
111+
112+
// IsNullable returns whether the return value can be null.
113+
func (c *CountDistinct) IsNullable() bool {
114+
return false
115+
}
116+
117+
// Resolved implements the Expression interface.
118+
func (c *CountDistinct) Resolved() bool {
119+
if _, ok := c.Child.(*expression.Star); ok {
120+
return true
121+
}
122+
123+
return c.Child.Resolved()
124+
}
125+
126+
func (c *CountDistinct) String() string {
127+
return fmt.Sprintf("COUNT(DISTINCT %s)", c.Child)
128+
}
129+
130+
// WithChildren implements the Expression interface.
131+
func (c *CountDistinct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
132+
if len(children) != 1 {
133+
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
134+
}
135+
return NewCountDistinct(children[0]), nil
136+
}
137+
138+
// Update implements the Aggregation interface.
139+
func (c *CountDistinct) Update(ctx *sql.Context, buffer, row sql.Row) error {
140+
seen := buffer[0].(map[uint64]struct{})
141+
var value interface{}
142+
if _, ok := c.Child.(*expression.Star); ok {
143+
value = row
144+
} else {
145+
v, err := c.Child.Eval(ctx, row)
146+
if v == nil {
147+
return nil
148+
}
149+
150+
if err != nil {
151+
return err
152+
}
153+
154+
value = v
155+
}
156+
157+
hash, err := hashstructure.Hash(value, nil)
158+
if err != nil {
159+
return fmt.Errorf("count distinct unable to hash value: %s", err)
160+
}
161+
162+
seen[hash] = struct{}{}
163+
164+
return nil
165+
}
166+
167+
// Merge implements the Aggregation interface.
168+
func (c *CountDistinct) Merge(ctx *sql.Context, buffer, partial sql.Row) error {
169+
seen := buffer[0].(map[uint64]struct{})
170+
for k := range partial[0].(map[uint64]struct{}) {
171+
seen[k] = struct{}{}
172+
}
173+
return nil
174+
}
175+
176+
// Eval implements the Aggregation interface.
177+
func (c *CountDistinct) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
178+
seen := buffer[0].(map[uint64]struct{})
179+
return int64(len(seen)), nil
180+
}

sql/expression/function/aggregation/count_test.go

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@ package aggregation
33
import (
44
"testing"
55

6-
"github.com/stretchr/testify/require"
76
"github.com/src-d/go-mysql-server/sql"
87
"github.com/src-d/go-mysql-server/sql/expression"
8+
"github.com/stretchr/testify/require"
99
)
1010

11-
func TestCount_String(t *testing.T) {
12-
require := require.New(t)
13-
14-
c := NewCount(expression.NewLiteral("foo", sql.Text))
15-
require.Equal(`COUNT("foo")`, c.String())
16-
}
17-
18-
func TestCount_Eval_1(t *testing.T) {
11+
func TestCountEval1(t *testing.T) {
1912
require := require.New(t)
2013
ctx := sql.NewEmptyContext()
2114

@@ -37,39 +30,96 @@ func TestCount_Eval_1(t *testing.T) {
3730
require.Equal(int64(7), eval(t, c, b))
3831
}
3932

40-
func TestCount_Eval_Star(t *testing.T) {
33+
func TestCountEvalStar(t *testing.T) {
4134
require := require.New(t)
4235
ctx := sql.NewEmptyContext()
4336

4437
c := NewCount(expression.NewStar())
4538
b := c.NewBuffer()
4639
require.Equal(int64(0), eval(t, c, b))
4740

48-
c.Update(ctx, b, nil)
49-
c.Update(ctx, b, sql.NewRow("foo"))
50-
c.Update(ctx, b, sql.NewRow(1))
51-
c.Update(ctx, b, sql.NewRow(nil))
52-
c.Update(ctx, b, sql.NewRow(1, 2, 3))
41+
require.NoError(c.Update(ctx, b, nil))
42+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
43+
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
44+
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
45+
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
5346
require.Equal(int64(5), eval(t, c, b))
5447

5548
b2 := c.NewBuffer()
56-
c.Update(ctx, b2, sql.NewRow())
57-
c.Update(ctx, b2, sql.NewRow("foo"))
58-
c.Merge(ctx, b, b2)
49+
require.NoError(c.Update(ctx, b2, sql.NewRow()))
50+
require.NoError(c.Update(ctx, b2, sql.NewRow("foo")))
51+
require.NoError(c.Merge(ctx, b, b2))
5952
require.Equal(int64(7), eval(t, c, b))
6053
}
6154

62-
func TestCount_Eval_String(t *testing.T) {
55+
func TestCountEvalString(t *testing.T) {
6356
require := require.New(t)
6457
ctx := sql.NewEmptyContext()
6558

6659
c := NewCount(expression.NewGetField(0, sql.Text, "", true))
6760
b := c.NewBuffer()
6861
require.Equal(int64(0), eval(t, c, b))
6962

70-
c.Update(ctx, b, sql.NewRow("foo"))
63+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
7164
require.Equal(int64(1), eval(t, c, b))
7265

73-
c.Update(ctx, b, sql.NewRow(nil))
66+
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
7467
require.Equal(int64(1), eval(t, c, b))
7568
}
69+
70+
func TestCountDistinctEval1(t *testing.T) {
71+
require := require.New(t)
72+
ctx := sql.NewEmptyContext()
73+
74+
c := NewCountDistinct(expression.NewLiteral(1, sql.Int32))
75+
b := c.NewBuffer()
76+
require.Equal(int64(0), eval(t, c, b))
77+
78+
require.NoError(c.Update(ctx, b, nil))
79+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
80+
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
81+
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
82+
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
83+
require.Equal(int64(1), eval(t, c, b))
84+
}
85+
86+
func TestCountDistinctEvalStar(t *testing.T) {
87+
require := require.New(t)
88+
ctx := sql.NewEmptyContext()
89+
90+
c := NewCountDistinct(expression.NewStar())
91+
b := c.NewBuffer()
92+
require.Equal(int64(0), eval(t, c, b))
93+
94+
require.NoError(c.Update(ctx, b, nil))
95+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
96+
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
97+
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
98+
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
99+
require.Equal(int64(5), eval(t, c, b))
100+
101+
b2 := c.NewBuffer()
102+
require.NoError(c.Update(ctx, b2, sql.NewRow(1)))
103+
require.NoError(c.Update(ctx, b2, sql.NewRow("foo")))
104+
require.NoError(c.Update(ctx, b2, sql.NewRow(5)))
105+
require.NoError(c.Merge(ctx, b, b2))
106+
107+
require.Equal(int64(6), eval(t, c, b))
108+
}
109+
110+
func TestCountDistinctEvalString(t *testing.T) {
111+
require := require.New(t)
112+
ctx := sql.NewEmptyContext()
113+
114+
c := NewCountDistinct(expression.NewGetField(0, sql.Text, "", true))
115+
b := c.NewBuffer()
116+
require.Equal(int64(0), eval(t, c, b))
117+
118+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
119+
require.Equal(int64(1), eval(t, c, b))
120+
121+
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
122+
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
123+
require.NoError(c.Update(ctx, b, sql.NewRow("bar")))
124+
require.Equal(int64(2), eval(t, c, b))
125+
}

sql/parse/parse.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/src-d/go-mysql-server/sql"
1414
"github.com/src-d/go-mysql-server/sql/expression"
1515
"github.com/src-d/go-mysql-server/sql/expression/function"
16+
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
1617
"github.com/src-d/go-mysql-server/sql/plan"
1718
"gopkg.in/src-d/go-errors.v1"
1819
"vitess.io/vitess/go/vt/sqlparser"
@@ -659,9 +660,11 @@ func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64,
659660
func isAggregate(e sql.Expression) bool {
660661
var isAgg bool
661662
expression.Inspect(e, func(e sql.Expression) bool {
662-
fn, ok := e.(*expression.UnresolvedFunction)
663-
if ok {
664-
isAgg = isAgg || fn.IsAggregate
663+
switch e := e.(type) {
664+
case *expression.UnresolvedFunction:
665+
isAgg = isAgg || e.IsAggregate
666+
case *aggregation.CountDistinct:
667+
isAgg = true
665668
}
666669

667670
return true
@@ -789,7 +792,15 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
789792
}
790793

791794
if v.Distinct {
792-
return nil, ErrUnsupportedSyntax.New("DISTINCT on aggregations")
795+
if v.Name.Lowered() != "count" {
796+
return nil, ErrUnsupportedSyntax.New("DISTINCT on non-COUNT aggregations")
797+
}
798+
799+
if len(exprs) != 1 {
800+
return nil, ErrUnsupportedSyntax.New("more than one expression in COUNT")
801+
}
802+
803+
return aggregation.NewCountDistinct(exprs[0]), nil
793804
}
794805

795806
return expression.NewUnresolvedFunction(v.Name.Lowered(),

sql/parse/parse_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/src-d/go-mysql-server/sql/expression"
7+
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
78
"github.com/src-d/go-mysql-server/sql/plan"
89
"gopkg.in/src-d/go-errors.v1"
910

@@ -1158,6 +1159,13 @@ var fixtures = map[string]sql.Node{
11581159
[]sql.Expression{},
11591160
plan.NewUnresolvedTable("foo", ""),
11601161
),
1162+
`SELECT COUNT(DISTINCT i) FROM foo`: plan.NewGroupBy(
1163+
[]sql.Expression{
1164+
aggregation.NewCountDistinct(expression.NewUnresolvedColumn("i")),
1165+
},
1166+
[]sql.Expression{},
1167+
plan.NewUnresolvedTable("foo", ""),
1168+
),
11611169
}
11621170

11631171
func TestParse(t *testing.T) {
@@ -1191,7 +1199,7 @@ var fixturesErrors = map[string]*errors.Kind{
11911199
`SELECT '2018-05-01' / INTERVAL 1 DAY`: ErrUnsupportedSyntax,
11921200
`SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax,
11931201
`SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax,
1194-
`SELECT COUNT(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
1202+
`SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
11951203
}
11961204

11971205
func TestParseErrors(t *testing.T) {

0 commit comments

Comments
 (0)