Skip to content

Commit 391980c

Browse files
authored
Support compare functions with SortSlices and SortMaps (#367)
The SortSlices and SortMaps options predate generics and accept an interface{}, so it is possible with reflection to support other function signatures than "func(T, T) bool". In particular, the Go ecosystem is increasingly moving towards "func(T, T) int" as the signature for ordering as evidenced by the newer slices.SortFunc function in stdlib. Thus, modernize cmpopts by supporting "func(T, T) int". Also, bump the minimum version to Go 1.21 to match the minimum supported version of google.golang.org/protobuf. Fixes #365
1 parent c3ad843 commit 391980c

File tree

5 files changed

+87
-36
lines changed

5 files changed

+87
-36
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
test:
77
strategy:
88
matrix:
9-
go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x]
9+
go-version: [1.21.x]
1010
os: [ubuntu-latest, macos-latest]
1111
runs-on: ${{ matrix.os }}
1212
steps:

cmp/cmpopts/sort.go

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,29 @@ import (
1414
)
1515

1616
// SortSlices returns a [cmp.Transformer] option that sorts all []V.
17-
// The less function must be of the form "func(T, T) bool" which is used to
18-
// sort any slice with element type V that is assignable to T.
17+
// The lessOrCompareFunc function must be either
18+
// a less function of the form "func(T, T) bool" or
19+
// a compare function of the format "func(T, T) int"
20+
// which is used to sort any slice with element type V that is assignable to T.
1921
//
20-
// The less function must be:
22+
// A less function must be:
2123
// - Deterministic: less(x, y) == less(x, y)
2224
// - Irreflexive: !less(x, x)
2325
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
2426
//
25-
// The less function does not have to be "total". That is, if !less(x, y) and
26-
// !less(y, x) for two elements x and y, their relative order is maintained.
27+
// A compare function must be:
28+
// - Deterministic: compare(x, y) == compare(x, y)
29+
// - Irreflexive: compare(x, x) == 0
30+
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
31+
//
32+
// The function does not have to be "total". That is, if x != y, but
33+
// less or compare report inequality, their relative order is maintained.
2734
//
2835
// SortSlices can be used in conjunction with [EquateEmpty].
29-
func SortSlices(lessFunc interface{}) cmp.Option {
30-
vf := reflect.ValueOf(lessFunc)
31-
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
32-
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
36+
func SortSlices(lessOrCompareFunc interface{}) cmp.Option {
37+
vf := reflect.ValueOf(lessOrCompareFunc)
38+
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
39+
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
3340
}
3441
ss := sliceSorter{vf.Type().In(0), vf}
3542
return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort))
@@ -79,28 +86,40 @@ func (ss sliceSorter) checkSort(v reflect.Value) {
7986
}
8087
func (ss sliceSorter) less(v reflect.Value, i, j int) bool {
8188
vx, vy := v.Index(i), v.Index(j)
82-
return ss.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
89+
vo := ss.fnc.Call([]reflect.Value{vx, vy})[0]
90+
if vo.Kind() == reflect.Bool {
91+
return vo.Bool()
92+
} else {
93+
return vo.Int() < 0
94+
}
8395
}
8496

85-
// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be a
86-
// sorted []struct{K, V}. The less function must be of the form
87-
// "func(T, T) bool" which is used to sort any map with key K that is
88-
// assignable to T.
97+
// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be
98+
// a sorted []struct{K, V}. The lessOrCompareFunc function must be either
99+
// a less function of the form "func(T, T) bool" or
100+
// a compare function of the format "func(T, T) int"
101+
// which is used to sort any map with key K that is assignable to T.
89102
//
90103
// Flattening the map into a slice has the property that [cmp.Equal] is able to
91104
// use [cmp.Comparer] options on K or the K.Equal method if it exists.
92105
//
93-
// The less function must be:
106+
// A less function must be:
94107
// - Deterministic: less(x, y) == less(x, y)
95108
// - Irreflexive: !less(x, x)
96109
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
97110
// - Total: if x != y, then either less(x, y) or less(y, x)
98111
//
112+
// A compare function must be:
113+
// - Deterministic: compare(x, y) == compare(x, y)
114+
// - Irreflexive: compare(x, x) == 0
115+
// - Transitive: if compare(x, y) < 0 and compare(y, z) < 0, then compare(x, z) < 0
116+
// - Total: if x != y, then compare(x, y) != 0
117+
//
99118
// SortMaps can be used in conjunction with [EquateEmpty].
100-
func SortMaps(lessFunc interface{}) cmp.Option {
101-
vf := reflect.ValueOf(lessFunc)
102-
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
103-
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
119+
func SortMaps(lessOrCompareFunc interface{}) cmp.Option {
120+
vf := reflect.ValueOf(lessOrCompareFunc)
121+
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
122+
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
104123
}
105124
ms := mapSorter{vf.Type().In(0), vf}
106125
return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort))
@@ -143,5 +162,10 @@ func (ms mapSorter) checkSort(v reflect.Value) {
143162
}
144163
func (ms mapSorter) less(v reflect.Value, i, j int) bool {
145164
vx, vy := v.Index(i).Field(0), v.Index(j).Field(0)
146-
return ms.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
165+
vo := ms.fnc.Call([]reflect.Value{vx, vy})[0]
166+
if vo.Kind() == reflect.Bool {
167+
return vo.Bool()
168+
} else {
169+
return vo.Int() < 0
170+
}
147171
}

cmp/cmpopts/util_test.go

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,23 @@ func TestOptions(t *testing.T) {
130130
opts: []cmp.Option{SortSlices(func(x, y int) bool { return x < y })},
131131
wantEqual: true,
132132
reason: "equal because SortSlices sorts the slices",
133+
}, {
134+
label: "SortSlices",
135+
x: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
136+
y: []int{1, 0, 5, 2, 8, 9, 4, 3, 6, 7},
137+
opts: []cmp.Option{SortSlices(func(x, y int) int {
138+
// TODO(Go1.22): Use cmp.Compare.
139+
switch {
140+
case x < y:
141+
return -1
142+
case y > x:
143+
return +1
144+
default:
145+
return 0
146+
}
147+
})},
148+
wantEqual: true,
149+
reason: "equal because SortSlices sorts the slices",
133150
}, {
134151
label: "SortSlices",
135152
x: []MyInt{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
@@ -201,6 +218,21 @@ func TestOptions(t *testing.T) {
201218
opts: []cmp.Option{SortMaps(func(x, y time.Time) bool { return x.Before(y) })},
202219
wantEqual: true,
203220
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
221+
}, {
222+
label: "SortMaps",
223+
x: map[time.Time]string{
224+
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC): "0th birthday",
225+
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC): "1st birthday",
226+
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC): "2nd birthday",
227+
},
228+
y: map[time.Time]string{
229+
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "0th birthday",
230+
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "1st birthday",
231+
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "2nd birthday",
232+
},
233+
opts: []cmp.Option{SortMaps(func(x, y time.Time) int { return time.Time.Compare(x, y) })},
234+
wantEqual: true,
235+
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
204236
}, {
205237
label: "SortMaps",
206238
x: map[MyTime]string{
@@ -1184,29 +1216,17 @@ func TestPanic(t *testing.T) {
11841216
args: args(time.Duration(-1)),
11851217
wantPanic: "margin must be a non-negative number",
11861218
reason: "negative duration is invalid",
1187-
}, {
1188-
label: "SortSlices",
1189-
fnc: SortSlices,
1190-
args: args(strings.Compare),
1191-
wantPanic: "invalid less function",
1192-
reason: "func(x, y string) int is wrong signature for less",
11931219
}, {
11941220
label: "SortSlices",
11951221
fnc: SortSlices,
11961222
args: args((func(_, _ int) bool)(nil)),
1197-
wantPanic: "invalid less function",
1223+
wantPanic: "invalid less or compare function",
11981224
reason: "nil value is not valid",
1199-
}, {
1200-
label: "SortMaps",
1201-
fnc: SortMaps,
1202-
args: args(strings.Compare),
1203-
wantPanic: "invalid less function",
1204-
reason: "func(x, y string) int is wrong signature for less",
12051225
}, {
12061226
label: "SortMaps",
12071227
fnc: SortMaps,
12081228
args: args((func(_, _ int) bool)(nil)),
1209-
wantPanic: "invalid less function",
1229+
wantPanic: "invalid less or compare function",
12101230
reason: "nil value is not valid",
12111231
}, {
12121232
label: "IgnoreFields",

cmp/internal/function/func.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const (
1919

2020
tbFunc // func(T) bool
2121
ttbFunc // func(T, T) bool
22+
ttiFunc // func(T, T) int
2223
trbFunc // func(T, R) bool
2324
tibFunc // func(T, I) bool
2425
trFunc // func(T) R
@@ -28,11 +29,13 @@ const (
2829
Transformer = trFunc // func(T) R
2930
ValueFilter = ttbFunc // func(T, T) bool
3031
Less = ttbFunc // func(T, T) bool
32+
Compare = ttiFunc // func(T, T) int
3133
ValuePredicate = tbFunc // func(T) bool
3234
KeyValuePredicate = trbFunc // func(T, R) bool
3335
)
3436

3537
var boolType = reflect.TypeOf(true)
38+
var intType = reflect.TypeOf(0)
3639

3740
// IsType reports whether the reflect.Type is of the specified function type.
3841
func IsType(t reflect.Type, ft funcType) bool {
@@ -49,6 +52,10 @@ func IsType(t reflect.Type, ft funcType) bool {
4952
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == boolType {
5053
return true
5154
}
55+
case ttiFunc: // func(T, T) int
56+
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == intType {
57+
return true
58+
}
5259
case trbFunc: // func(T, R) bool
5360
if ni == 2 && no == 1 && t.Out(0) == boolType {
5461
return true

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
module github.com/google/go-cmp
22

3-
go 1.13
3+
go 1.21

0 commit comments

Comments
 (0)