Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 8405ee2

Browse files
authored
sql/expression: allow null literals in case branches (#741)
sql/expression: allow null literals in case branches
2 parents 26a0ec9 + dfa0945 commit 8405ee2

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

Diff for: engine_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,10 @@ var queries = []struct {
12301230
`SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`,
12311231
[]sql.Row{{"6789"}},
12321232
},
1233+
{
1234+
`SELECT CASE i WHEN 1 THEN i ELSE NULL END FROM mytable`,
1235+
[]sql.Row{{int64(1)}, {nil}, {nil}},
1236+
},
12331237
}
12341238

12351239
func TestQueries(t *testing.T) {

Diff for: sql/analyzer/validation_rules.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package analyzer
33
import (
44
"strings"
55

6-
"gopkg.in/src-d/go-errors.v1"
76
"github.com/src-d/go-mysql-server/sql"
87
"github.com/src-d/go-mysql-server/sql/expression"
98
"github.com/src-d/go-mysql-server/sql/expression/function"
109
"github.com/src-d/go-mysql-server/sql/plan"
10+
"gopkg.in/src-d/go-errors.v1"
1111
)
1212

1313
const (
@@ -242,14 +242,14 @@ func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Nod
242242
case *expression.Case:
243243
typ := e.Type()
244244
for _, b := range e.Branches {
245-
if b.Value.Type() != typ {
245+
if b.Value.Type() != typ && b.Value.Type() != sql.Null {
246246
err = ErrCaseResultType.New(typ, b.Value, b.Value.Type(), e)
247247
return false
248248
}
249249
}
250250

251251
if e.Else != nil {
252-
if e.Else.Type() != typ {
252+
if e.Else.Type() != typ && e.Else.Type() != sql.Null {
253253
err = ErrCaseResultType.New(typ, e.Else, e.Else.Type(), e)
254254
return false
255255
}

Diff for: sql/expression/case.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,16 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
2828
// Type implements the sql.Expression interface.
2929
func (c *Case) Type() sql.Type {
3030
for _, b := range c.Branches {
31-
return b.Value.Type()
31+
if b.Value.Type() != sql.Null {
32+
return b.Value.Type()
33+
}
34+
}
35+
36+
if c.Else.Type() != sql.Null {
37+
return c.Else.Type()
3238
}
33-
return c.Else.Type()
39+
40+
return sql.Null
3441
}
3542

3643
// IsNullable implements the sql.Expression interface.

Diff for: sql/expression/case_test.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ package expression
33
import (
44
"testing"
55

6-
"github.com/stretchr/testify/require"
76
"github.com/src-d/go-mysql-server/sql"
7+
"github.com/stretchr/testify/require"
88
)
99

1010
func TestCase(t *testing.T) {
@@ -127,3 +127,20 @@ func TestCase(t *testing.T) {
127127
})
128128
}
129129
}
130+
131+
func TestCaseNullBranch(t *testing.T) {
132+
require := require.New(t)
133+
f := NewCase(
134+
NewGetField(0, sql.Int64, "x", false),
135+
[]CaseBranch{
136+
{
137+
Cond: NewLiteral(int64(1), sql.Int64),
138+
Value: NewLiteral(nil, sql.Null),
139+
},
140+
},
141+
nil,
142+
)
143+
result, err := f.Eval(sql.NewEmptyContext(), sql.Row{int64(1)})
144+
require.NoError(err)
145+
require.Nil(result)
146+
}

0 commit comments

Comments
 (0)