diff --git a/sql/type.go b/sql/type.go index eaac23b21..caadd9254 100644 --- a/sql/type.go +++ b/sql/type.go @@ -402,6 +402,10 @@ func (t numberT) Compare(a interface{}, b interface{}) (int, error) { func (t numberT) String() string { return t.t.String() } func compareFloats(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + ca, err := cast.ToFloat64E(a) if err != nil { return 0, err @@ -423,6 +427,10 @@ func compareFloats(a interface{}, b interface{}) (int, error) { } func compareSignedInts(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + ca, err := cast.ToInt64E(a) if err != nil { return 0, err @@ -444,6 +452,10 @@ func compareSignedInts(a interface{}, b interface{}) (int, error) { } func compareUnsignedInts(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + ca, err := cast.ToUint64E(a) if err != nil { return 0, err @@ -540,6 +552,10 @@ func (t timestampT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t timestampT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + av := a.(time.Time) bv := b.(time.Time) if av.Before(bv) { @@ -603,6 +619,10 @@ func (t dateT) Convert(v interface{}) (interface{}, error) { } func (t dateT) Compare(a, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + av := truncateDate(a.(time.Time)) bv := truncateDate(b.(time.Time)) if av.Before(bv) { @@ -758,6 +778,9 @@ func (t varCharT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t varCharT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return strings.Compare(a.(string), b.(string)), nil } @@ -795,6 +818,9 @@ func (t textT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t textT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return strings.Compare(a.(string), b.(string)), nil } @@ -847,6 +873,10 @@ func (t booleanT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t booleanT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + if a == b { return 0, nil } @@ -899,6 +929,9 @@ func (t blobT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t blobT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return bytes.Compare(a.([]byte), b.([]byte)), nil } @@ -941,6 +974,9 @@ func (t jsonT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t jsonT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return bytes.Compare(a.([]byte), b.([]byte)), nil } @@ -1302,3 +1338,19 @@ func convertArrayForJSON(t arrayT, v interface{}) (interface{}, error) { return nil, ErrNotArray.New(v) } } + +// compareNulls compares two values, and returns true if either is null. +// The returned integer represents the ordering, with a rule that states nulls +// as being ordered before non-nulls. +func compareNulls(a interface{}, b interface{}) (bool, int) { + aIsNull := a == nil + bIsNull := b == nil + if aIsNull && bIsNull { + return true, 0 + } else if aIsNull && !bIsNull { + return true, -1 + } else if !aIsNull && bIsNull { + return true, 1 + } + return false, 0 +} diff --git a/sql/type_test.go b/sql/type_test.go index 1087fb423..a73bf08be 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -444,6 +444,46 @@ func TestJSONArraySQL(t *testing.T) { require.Equal(expected, string(val.Raw())) } +func TestComparesWithNulls(t *testing.T) { + timeParse := func(layout string, value string) time.Time { + t, err := time.Parse(layout, value) + if err != nil { + panic(err) + } + return t + } + + var typeVals = []struct { + typ Type + val interface{} + }{ + {Int8, int8(0)}, + {Uint8, uint8(0)}, + {Int16, int16(0)}, + {Uint16, uint16(0)}, + {Int32, int32(0)}, + {Uint32, uint32(0)}, + {Int64, int64(0)}, + {Uint64, uint64(0)}, + {Float32, float32(0)}, + {Float64, float64(0)}, + {Timestamp, timeParse(TimestampLayout, "2132-04-05 12:51:36")}, + {Date, timeParse(DateLayout, "2231-11-07")}, + {Text, ""}, + {Boolean, false}, + {JSON, `{}`}, + {Blob, ""}, + } + + for _, typeVal := range typeVals { + t.Run(typeVal.typ.String(), func(t *testing.T) { + lt(t, typeVal.typ, nil, typeVal.val) + gt(t, typeVal.typ, typeVal.val, nil) + eq(t, typeVal.typ, nil, nil) + }) + } +} + func eq(t *testing.T, typ Type, a, b interface{}) { t.Helper() cmp, err := typ.Compare(a, b)