Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 1836395

Browse files
authoredOct 1, 2019
Fixed null errors during value comparisons (#831)
Fixed null errors during value comparisons
2 parents d161e2d + 6d0cfa6 commit 1836395

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed
 

Diff for: ‎sql/type.go

+52
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ func (t numberT) Compare(a interface{}, b interface{}) (int, error) {
402402
func (t numberT) String() string { return t.t.String() }
403403

404404
func compareFloats(a interface{}, b interface{}) (int, error) {
405+
if hasNulls, res := compareNulls(a, b); hasNulls {
406+
return res, nil
407+
}
408+
405409
ca, err := cast.ToFloat64E(a)
406410
if err != nil {
407411
return 0, err
@@ -423,6 +427,10 @@ func compareFloats(a interface{}, b interface{}) (int, error) {
423427
}
424428

425429
func compareSignedInts(a interface{}, b interface{}) (int, error) {
430+
if hasNulls, res := compareNulls(a, b); hasNulls {
431+
return res, nil
432+
}
433+
426434
ca, err := cast.ToInt64E(a)
427435
if err != nil {
428436
return 0, err
@@ -444,6 +452,10 @@ func compareSignedInts(a interface{}, b interface{}) (int, error) {
444452
}
445453

446454
func compareUnsignedInts(a interface{}, b interface{}) (int, error) {
455+
if hasNulls, res := compareNulls(a, b); hasNulls {
456+
return res, nil
457+
}
458+
447459
ca, err := cast.ToUint64E(a)
448460
if err != nil {
449461
return 0, err
@@ -540,6 +552,10 @@ func (t timestampT) Convert(v interface{}) (interface{}, error) {
540552

541553
// Compare implements Type interface.
542554
func (t timestampT) Compare(a interface{}, b interface{}) (int, error) {
555+
if hasNulls, res := compareNulls(a, b); hasNulls {
556+
return res, nil
557+
}
558+
543559
av := a.(time.Time)
544560
bv := b.(time.Time)
545561
if av.Before(bv) {
@@ -603,6 +619,10 @@ func (t dateT) Convert(v interface{}) (interface{}, error) {
603619
}
604620

605621
func (t dateT) Compare(a, b interface{}) (int, error) {
622+
if hasNulls, res := compareNulls(a, b); hasNulls {
623+
return res, nil
624+
}
625+
606626
av := truncateDate(a.(time.Time))
607627
bv := truncateDate(b.(time.Time))
608628
if av.Before(bv) {
@@ -758,6 +778,9 @@ func (t varCharT) Convert(v interface{}) (interface{}, error) {
758778

759779
// Compare implements Type interface.
760780
func (t varCharT) Compare(a interface{}, b interface{}) (int, error) {
781+
if hasNulls, res := compareNulls(a, b); hasNulls {
782+
return res, nil
783+
}
761784
return strings.Compare(a.(string), b.(string)), nil
762785
}
763786

@@ -795,6 +818,9 @@ func (t textT) Convert(v interface{}) (interface{}, error) {
795818

796819
// Compare implements Type interface.
797820
func (t textT) Compare(a interface{}, b interface{}) (int, error) {
821+
if hasNulls, res := compareNulls(a, b); hasNulls {
822+
return res, nil
823+
}
798824
return strings.Compare(a.(string), b.(string)), nil
799825
}
800826

@@ -847,6 +873,10 @@ func (t booleanT) Convert(v interface{}) (interface{}, error) {
847873

848874
// Compare implements Type interface.
849875
func (t booleanT) Compare(a interface{}, b interface{}) (int, error) {
876+
if hasNulls, res := compareNulls(a, b); hasNulls {
877+
return res, nil
878+
}
879+
850880
if a == b {
851881
return 0, nil
852882
}
@@ -899,6 +929,9 @@ func (t blobT) Convert(v interface{}) (interface{}, error) {
899929

900930
// Compare implements Type interface.
901931
func (t blobT) Compare(a interface{}, b interface{}) (int, error) {
932+
if hasNulls, res := compareNulls(a, b); hasNulls {
933+
return res, nil
934+
}
902935
return bytes.Compare(a.([]byte), b.([]byte)), nil
903936
}
904937

@@ -941,6 +974,9 @@ func (t jsonT) Convert(v interface{}) (interface{}, error) {
941974

942975
// Compare implements Type interface.
943976
func (t jsonT) Compare(a interface{}, b interface{}) (int, error) {
977+
if hasNulls, res := compareNulls(a, b); hasNulls {
978+
return res, nil
979+
}
944980
return bytes.Compare(a.([]byte), b.([]byte)), nil
945981
}
946982

@@ -1302,3 +1338,19 @@ func convertArrayForJSON(t arrayT, v interface{}) (interface{}, error) {
13021338
return nil, ErrNotArray.New(v)
13031339
}
13041340
}
1341+
1342+
// compareNulls compares two values, and returns true if either is null.
1343+
// The returned integer represents the ordering, with a rule that states nulls
1344+
// as being ordered before non-nulls.
1345+
func compareNulls(a interface{}, b interface{}) (bool, int) {
1346+
aIsNull := a == nil
1347+
bIsNull := b == nil
1348+
if aIsNull && bIsNull {
1349+
return true, 0
1350+
} else if aIsNull && !bIsNull {
1351+
return true, -1
1352+
} else if !aIsNull && bIsNull {
1353+
return true, 1
1354+
}
1355+
return false, 0
1356+
}

Diff for: ‎sql/type_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,46 @@ func TestJSONArraySQL(t *testing.T) {
444444
require.Equal(expected, string(val.Raw()))
445445
}
446446

447+
func TestComparesWithNulls(t *testing.T) {
448+
timeParse := func(layout string, value string) time.Time {
449+
t, err := time.Parse(layout, value)
450+
if err != nil {
451+
panic(err)
452+
}
453+
return t
454+
}
455+
456+
var typeVals = []struct {
457+
typ Type
458+
val interface{}
459+
}{
460+
{Int8, int8(0)},
461+
{Uint8, uint8(0)},
462+
{Int16, int16(0)},
463+
{Uint16, uint16(0)},
464+
{Int32, int32(0)},
465+
{Uint32, uint32(0)},
466+
{Int64, int64(0)},
467+
{Uint64, uint64(0)},
468+
{Float32, float32(0)},
469+
{Float64, float64(0)},
470+
{Timestamp, timeParse(TimestampLayout, "2132-04-05 12:51:36")},
471+
{Date, timeParse(DateLayout, "2231-11-07")},
472+
{Text, ""},
473+
{Boolean, false},
474+
{JSON, `{}`},
475+
{Blob, ""},
476+
}
477+
478+
for _, typeVal := range typeVals {
479+
t.Run(typeVal.typ.String(), func(t *testing.T) {
480+
lt(t, typeVal.typ, nil, typeVal.val)
481+
gt(t, typeVal.typ, typeVal.val, nil)
482+
eq(t, typeVal.typ, nil, nil)
483+
})
484+
}
485+
}
486+
447487
func eq(t *testing.T, typ Type, a, b interface{}) {
448488
t.Helper()
449489
cmp, err := typ.Compare(a, b)

0 commit comments

Comments
 (0)