From dfa09452ca06ce03fe82393eedb8c8bbe076782d Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 4 Jun 2019 11:55:27 +0200 Subject: [PATCH] sql/expression: allow null literals in case branches Signed-off-by: Miguel Molina --- engine_test.go | 4 ++++ sql/analyzer/validation_rules.go | 6 +++--- sql/expression/case.go | 11 +++++++++-- sql/expression/case_test.go | 19 ++++++++++++++++++- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/engine_test.go b/engine_test.go index 1b0d39cdc..dfda34bb2 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1230,6 +1230,10 @@ var queries = []struct { `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, []sql.Row{{"6789"}}, }, + { + `SELECT CASE i WHEN 1 THEN i ELSE NULL END FROM mytable`, + []sql.Row{{int64(1)}, {nil}, {nil}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 06d1f13b6..e7448d9e2 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -3,11 +3,11 @@ package analyzer import ( "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/expression/function" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) const ( @@ -242,14 +242,14 @@ func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Nod case *expression.Case: typ := e.Type() for _, b := range e.Branches { - if b.Value.Type() != typ { + if b.Value.Type() != typ && b.Value.Type() != sql.Null { err = ErrCaseResultType.New(typ, b.Value, b.Value.Type(), e) return false } } if e.Else != nil { - if e.Else.Type() != typ { + if e.Else.Type() != typ && e.Else.Type() != sql.Null { err = ErrCaseResultType.New(typ, e.Else, e.Else.Type(), e) return false } diff --git a/sql/expression/case.go b/sql/expression/case.go index 388f3a7ee..28eef6d03 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -28,9 +28,16 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression // Type implements the sql.Expression interface. func (c *Case) Type() sql.Type { for _, b := range c.Branches { - return b.Value.Type() + if b.Value.Type() != sql.Null { + return b.Value.Type() + } + } + + if c.Else.Type() != sql.Null { + return c.Else.Type() } - return c.Else.Type() + + return sql.Null } // IsNullable implements the sql.Expression interface. diff --git a/sql/expression/case_test.go b/sql/expression/case_test.go index 6319db2d7..80a24aa47 100644 --- a/sql/expression/case_test.go +++ b/sql/expression/case_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" ) func TestCase(t *testing.T) { @@ -127,3 +127,20 @@ func TestCase(t *testing.T) { }) } } + +func TestCaseNullBranch(t *testing.T) { + require := require.New(t) + f := NewCase( + NewGetField(0, sql.Int64, "x", false), + []CaseBranch{ + { + Cond: NewLiteral(int64(1), sql.Int64), + Value: NewLiteral(nil, sql.Null), + }, + }, + nil, + ) + result, err := f.Eval(sql.NewEmptyContext(), sql.Row{int64(1)}) + require.NoError(err) + require.Nil(result) +}