Skip to content

Commit 37accb4

Browse files
aliforeverndyakov
andauthored
fix: nil pointer dereferencing in writeArg (#3271)
* fixed bug with nil dereferencing in writeArg, added hset struct example, added tests * removed password from example * added omitempty * reverted xxhash versioning * reverted xxhash versioning * removed password * removed password --------- Co-authored-by: Nedyalko Dyakov <[email protected]>
1 parent 747190e commit 37accb4

File tree

7 files changed

+274
-29
lines changed

7 files changed

+274
-29
lines changed

Diff for: example/hset-struct/README.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Example for setting struct fields as hash fields
2+
3+
To run this example:
4+
5+
```shell
6+
go run .
7+
```

Diff for: example/hset-struct/go.mod

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module github.com/redis/go-redis/example/scan-struct
2+
3+
go 1.18
4+
5+
replace github.com/redis/go-redis/v9 => ../..
6+
7+
require (
8+
github.com/davecgh/go-spew v1.1.1
9+
github.com/redis/go-redis/v9 v9.6.2
10+
)
11+
12+
require (
13+
github.com/cespare/xxhash/v2 v2.2.0 // indirect
14+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
15+
)

Diff for: example/hset-struct/go.sum

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
2+
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
3+
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
4+
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
5+
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
6+
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
7+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
10+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

Diff for: example/hset-struct/main.go

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/davecgh/go-spew/spew"
8+
9+
"github.com/redis/go-redis/v9"
10+
)
11+
12+
type Model struct {
13+
Str1 string `redis:"str1"`
14+
Str2 string `redis:"str2"`
15+
Str3 *string `redis:"str3"`
16+
Str4 *string `redis:"str4"`
17+
Bytes []byte `redis:"bytes"`
18+
Int int `redis:"int"`
19+
Int2 *int `redis:"int2"`
20+
Int3 *int `redis:"int3"`
21+
Bool bool `redis:"bool"`
22+
Bool2 *bool `redis:"bool2"`
23+
Bool3 *bool `redis:"bool3"`
24+
Bool4 *bool `redis:"bool4,omitempty"`
25+
Time time.Time `redis:"time"`
26+
Time2 *time.Time `redis:"time2"`
27+
Time3 *time.Time `redis:"time3"`
28+
Ignored struct{} `redis:"-"`
29+
}
30+
31+
func main() {
32+
ctx := context.Background()
33+
34+
rdb := redis.NewClient(&redis.Options{
35+
Addr: ":6379",
36+
})
37+
38+
_ = rdb.FlushDB(ctx).Err()
39+
40+
t := time.Date(2025, 02, 8, 0, 0, 0, 0, time.UTC)
41+
42+
data := Model{
43+
Str1: "hello",
44+
Str2: "world",
45+
Str3: ToPtr("hello"),
46+
Str4: nil,
47+
Bytes: []byte("this is bytes !"),
48+
Int: 123,
49+
Int2: ToPtr(0),
50+
Int3: nil,
51+
Bool: true,
52+
Bool2: ToPtr(false),
53+
Bool3: nil,
54+
Time: t,
55+
Time2: ToPtr(t),
56+
Time3: nil,
57+
Ignored: struct{}{},
58+
}
59+
60+
// Set some fields.
61+
if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error {
62+
rdb.HMSet(ctx, "key", data)
63+
return nil
64+
}); err != nil {
65+
panic(err)
66+
}
67+
68+
var model1, model2 Model
69+
70+
// Scan all fields into the model.
71+
if err := rdb.HGetAll(ctx, "key").Scan(&model1); err != nil {
72+
panic(err)
73+
}
74+
75+
// Or scan a subset of the fields.
76+
if err := rdb.HMGet(ctx, "key", "str1", "int").Scan(&model2); err != nil {
77+
panic(err)
78+
}
79+
80+
spew.Dump(model1)
81+
// Output:
82+
// (main.Model) {
83+
// Str1: (string) (len=5) "hello",
84+
// Str2: (string) (len=5) "world",
85+
// Str3: (*string)(0xc000016970)((len=5) "hello"),
86+
// Str4: (*string)(0xc000016980)(""),
87+
// Bytes: ([]uint8) (len=15 cap=16) {
88+
// 00000000 74 68 69 73 20 69 73 20 62 79 74 65 73 20 21 |this is bytes !|
89+
// },
90+
// Int: (int) 123,
91+
// Int2: (*int)(0xc000014568)(0),
92+
// Int3: (*int)(0xc000014560)(0),
93+
// Bool: (bool) true,
94+
// Bool2: (*bool)(0xc000014570)(false),
95+
// Bool3: (*bool)(0xc000014548)(false),
96+
// Bool4: (*bool)(<nil>),
97+
// Time: (time.Time) 2025-02-08 00:00:00 +0000 UTC,
98+
// Time2: (*time.Time)(0xc0000122a0)(2025-02-08 00:00:00 +0000 UTC),
99+
// Time3: (*time.Time)(0xc000012288)(0001-01-01 00:00:00 +0000 UTC),
100+
// Ignored: (struct {}) {
101+
// }
102+
// }
103+
104+
spew.Dump(model2)
105+
// Output:
106+
// (main.Model) {
107+
// Str1: (string) (len=5) "hello",
108+
// Str2: (string) "",
109+
// Str3: (*string)(<nil>),
110+
// Str4: (*string)(<nil>),
111+
// Bytes: ([]uint8) <nil>,
112+
// Int: (int) 123,
113+
// Int2: (*int)(<nil>),
114+
// Int3: (*int)(<nil>),
115+
// Bool: (bool) false,
116+
// Bool2: (*bool)(<nil>),
117+
// Bool3: (*bool)(<nil>),
118+
// Bool4: (*bool)(<nil>),
119+
// Time: (time.Time) 0001-01-01 00:00:00 +0000 UTC,
120+
// Time2: (*time.Time)(<nil>),
121+
// Time3: (*time.Time)(<nil>),
122+
// Ignored: (struct {}) {
123+
// }
124+
// }
125+
}
126+
127+
func ToPtr[T any](v T) *T {
128+
return &v
129+
}

Diff for: example/scan-struct/main.go

+6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ import (
1111
type Model struct {
1212
Str1 string `redis:"str1"`
1313
Str2 string `redis:"str2"`
14+
Str3 *string `redis:"str3"`
1415
Bytes []byte `redis:"bytes"`
1516
Int int `redis:"int"`
17+
Int2 *int `redis:"int2"`
1618
Bool bool `redis:"bool"`
19+
Bool2 *bool `redis:"bool2"`
1720
Ignored struct{} `redis:"-"`
1821
}
1922

@@ -29,8 +32,11 @@ func main() {
2932
if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error {
3033
rdb.HSet(ctx, "key", "str1", "hello")
3134
rdb.HSet(ctx, "key", "str2", "world")
35+
rdb.HSet(ctx, "key", "str3", "")
3236
rdb.HSet(ctx, "key", "int", 123)
37+
rdb.HSet(ctx, "key", "int2", 0)
3338
rdb.HSet(ctx, "key", "bool", 1)
39+
rdb.HSet(ctx, "key", "bool2", 0)
3440
rdb.HSet(ctx, "key", "bytes", []byte("this is bytes !"))
3541
return nil
3642
}); err != nil {

Diff for: internal/proto/writer.go

+53
Original file line numberDiff line numberDiff line change
@@ -66,72 +66,125 @@ func (w *Writer) WriteArg(v interface{}) error {
6666
case string:
6767
return w.string(v)
6868
case *string:
69+
if v == nil {
70+
return w.string("")
71+
}
6972
return w.string(*v)
7073
case []byte:
7174
return w.bytes(v)
7275
case int:
7376
return w.int(int64(v))
7477
case *int:
78+
if v == nil {
79+
return w.int(0)
80+
}
7581
return w.int(int64(*v))
7682
case int8:
7783
return w.int(int64(v))
7884
case *int8:
85+
if v == nil {
86+
return w.int(0)
87+
}
7988
return w.int(int64(*v))
8089
case int16:
8190
return w.int(int64(v))
8291
case *int16:
92+
if v == nil {
93+
return w.int(0)
94+
}
8395
return w.int(int64(*v))
8496
case int32:
8597
return w.int(int64(v))
8698
case *int32:
99+
if v == nil {
100+
return w.int(0)
101+
}
87102
return w.int(int64(*v))
88103
case int64:
89104
return w.int(v)
90105
case *int64:
106+
if v == nil {
107+
return w.int(0)
108+
}
91109
return w.int(*v)
92110
case uint:
93111
return w.uint(uint64(v))
94112
case *uint:
113+
if v == nil {
114+
return w.uint(0)
115+
}
95116
return w.uint(uint64(*v))
96117
case uint8:
97118
return w.uint(uint64(v))
98119
case *uint8:
120+
if v == nil {
121+
return w.string("")
122+
}
99123
return w.uint(uint64(*v))
100124
case uint16:
101125
return w.uint(uint64(v))
102126
case *uint16:
127+
if v == nil {
128+
return w.uint(0)
129+
}
103130
return w.uint(uint64(*v))
104131
case uint32:
105132
return w.uint(uint64(v))
106133
case *uint32:
134+
if v == nil {
135+
return w.uint(0)
136+
}
107137
return w.uint(uint64(*v))
108138
case uint64:
109139
return w.uint(v)
110140
case *uint64:
141+
if v == nil {
142+
return w.uint(0)
143+
}
111144
return w.uint(*v)
112145
case float32:
113146
return w.float(float64(v))
114147
case *float32:
148+
if v == nil {
149+
return w.float(0)
150+
}
115151
return w.float(float64(*v))
116152
case float64:
117153
return w.float(v)
118154
case *float64:
155+
if v == nil {
156+
return w.float(0)
157+
}
119158
return w.float(*v)
120159
case bool:
121160
if v {
122161
return w.int(1)
123162
}
124163
return w.int(0)
125164
case *bool:
165+
if v == nil {
166+
return w.int(0)
167+
}
126168
if *v {
127169
return w.int(1)
128170
}
129171
return w.int(0)
130172
case time.Time:
131173
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
132174
return w.bytes(w.numBuf)
175+
case *time.Time:
176+
if v == nil {
177+
v = &time.Time{}
178+
}
179+
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
180+
return w.bytes(w.numBuf)
133181
case time.Duration:
134182
return w.int(v.Nanoseconds())
183+
case *time.Duration:
184+
if v == nil {
185+
return w.int(0)
186+
}
187+
return w.int(v.Nanoseconds())
135188
case encoding.BinaryMarshaler:
136189
b, err := v.MarshalBinary()
137190
if err != nil {

0 commit comments

Comments
 (0)