Skip to content

fix: nil pointer dereferencing in writeArg #3271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 20, 2025
7 changes: 7 additions & 0 deletions example/hset-struct/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Example for setting struct fields as hash fields

To run this example:

```shell
go run .
```
15 changes: 15 additions & 0 deletions example/hset-struct/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module github.com/redis/go-redis/example/scan-struct

go 1.18

replace github.com/redis/go-redis/v9 => ../..

require (
github.com/davecgh/go-spew v1.1.1
github.com/redis/go-redis/v9 v9.6.2
)

require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)
10 changes: 10 additions & 0 deletions example/hset-struct/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
129 changes: 129 additions & 0 deletions example/hset-struct/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"time"

"github.com/davecgh/go-spew/spew"

"github.com/redis/go-redis/v9"
)

type Model struct {
Str1 string `redis:"str1"`
Str2 string `redis:"str2"`
Str3 *string `redis:"str3"`
Str4 *string `redis:"str4"`
Bytes []byte `redis:"bytes"`
Int int `redis:"int"`
Int2 *int `redis:"int2"`
Int3 *int `redis:"int3"`
Bool bool `redis:"bool"`
Bool2 *bool `redis:"bool2"`
Bool3 *bool `redis:"bool3"`
Bool4 *bool `redis:"bool4,omitempty"`
Time time.Time `redis:"time"`
Time2 *time.Time `redis:"time2"`
Time3 *time.Time `redis:"time3"`
Ignored struct{} `redis:"-"`
}

func main() {
ctx := context.Background()

rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})

_ = rdb.FlushDB(ctx).Err()

t := time.Date(2025, 02, 8, 0, 0, 0, 0, time.UTC)

data := Model{
Str1: "hello",
Str2: "world",
Str3: ToPtr("hello"),
Str4: nil,
Bytes: []byte("this is bytes !"),
Int: 123,
Int2: ToPtr(0),
Int3: nil,
Bool: true,
Bool2: ToPtr(false),
Bool3: nil,
Time: t,
Time2: ToPtr(t),
Time3: nil,
Ignored: struct{}{},
}

// Set some fields.
if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error {
rdb.HMSet(ctx, "key", data)
return nil
}); err != nil {
panic(err)
}

var model1, model2 Model

// Scan all fields into the model.
if err := rdb.HGetAll(ctx, "key").Scan(&model1); err != nil {
panic(err)
}

// Or scan a subset of the fields.
if err := rdb.HMGet(ctx, "key", "str1", "int").Scan(&model2); err != nil {
panic(err)
}

spew.Dump(model1)
// Output:
// (main.Model) {
// Str1: (string) (len=5) "hello",
// Str2: (string) (len=5) "world",
// Str3: (*string)(0xc000016970)((len=5) "hello"),
// Str4: (*string)(0xc000016980)(""),
// Bytes: ([]uint8) (len=15 cap=16) {
// 00000000 74 68 69 73 20 69 73 20 62 79 74 65 73 20 21 |this is bytes !|
// },
// Int: (int) 123,
// Int2: (*int)(0xc000014568)(0),
// Int3: (*int)(0xc000014560)(0),
// Bool: (bool) true,
// Bool2: (*bool)(0xc000014570)(false),
// Bool3: (*bool)(0xc000014548)(false),
// Bool4: (*bool)(<nil>),
// Time: (time.Time) 2025-02-08 00:00:00 +0000 UTC,
// Time2: (*time.Time)(0xc0000122a0)(2025-02-08 00:00:00 +0000 UTC),
// Time3: (*time.Time)(0xc000012288)(0001-01-01 00:00:00 +0000 UTC),
// Ignored: (struct {}) {
// }
// }

spew.Dump(model2)
// Output:
// (main.Model) {
// Str1: (string) (len=5) "hello",
// Str2: (string) "",
// Str3: (*string)(<nil>),
// Str4: (*string)(<nil>),
// Bytes: ([]uint8) <nil>,
// Int: (int) 123,
// Int2: (*int)(<nil>),
// Int3: (*int)(<nil>),
// Bool: (bool) false,
// Bool2: (*bool)(<nil>),
// Bool3: (*bool)(<nil>),
// Bool4: (*bool)(<nil>),
// Time: (time.Time) 0001-01-01 00:00:00 +0000 UTC,
// Time2: (*time.Time)(<nil>),
// Time3: (*time.Time)(<nil>),
// Ignored: (struct {}) {
// }
// }
}

func ToPtr[T any](v T) *T {
return &v
}
6 changes: 6 additions & 0 deletions example/scan-struct/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ import (
type Model struct {
Str1 string `redis:"str1"`
Str2 string `redis:"str2"`
Str3 *string `redis:"str3"`
Bytes []byte `redis:"bytes"`
Int int `redis:"int"`
Int2 *int `redis:"int2"`
Bool bool `redis:"bool"`
Bool2 *bool `redis:"bool2"`
Ignored struct{} `redis:"-"`
}

Expand All @@ -29,8 +32,11 @@ func main() {
if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error {
rdb.HSet(ctx, "key", "str1", "hello")
rdb.HSet(ctx, "key", "str2", "world")
rdb.HSet(ctx, "key", "str3", "")
rdb.HSet(ctx, "key", "int", 123)
rdb.HSet(ctx, "key", "int2", 0)
rdb.HSet(ctx, "key", "bool", 1)
rdb.HSet(ctx, "key", "bool2", 0)
rdb.HSet(ctx, "key", "bytes", []byte("this is bytes !"))
return nil
}); err != nil {
Expand Down
53 changes: 53 additions & 0 deletions internal/proto/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,72 +66,125 @@ func (w *Writer) WriteArg(v interface{}) error {
case string:
return w.string(v)
case *string:
if v == nil {
return w.string("")
}
return w.string(*v)
case []byte:
return w.bytes(v)
case int:
return w.int(int64(v))
case *int:
if v == nil {
return w.int(0)
}
return w.int(int64(*v))
case int8:
return w.int(int64(v))
case *int8:
if v == nil {
return w.int(0)
}
return w.int(int64(*v))
case int16:
return w.int(int64(v))
case *int16:
if v == nil {
return w.int(0)
}
return w.int(int64(*v))
case int32:
return w.int(int64(v))
case *int32:
if v == nil {
return w.int(0)
}
return w.int(int64(*v))
case int64:
return w.int(v)
case *int64:
if v == nil {
return w.int(0)
}
return w.int(*v)
case uint:
return w.uint(uint64(v))
case *uint:
if v == nil {
return w.uint(0)
}
return w.uint(uint64(*v))
case uint8:
return w.uint(uint64(v))
case *uint8:
if v == nil {
return w.string("")
}
return w.uint(uint64(*v))
case uint16:
return w.uint(uint64(v))
case *uint16:
if v == nil {
return w.uint(0)
}
return w.uint(uint64(*v))
case uint32:
return w.uint(uint64(v))
case *uint32:
if v == nil {
return w.uint(0)
}
return w.uint(uint64(*v))
case uint64:
return w.uint(v)
case *uint64:
if v == nil {
return w.uint(0)
}
return w.uint(*v)
case float32:
return w.float(float64(v))
case *float32:
if v == nil {
return w.float(0)
}
return w.float(float64(*v))
case float64:
return w.float(v)
case *float64:
if v == nil {
return w.float(0)
}
return w.float(*v)
case bool:
if v {
return w.int(1)
}
return w.int(0)
case *bool:
if v == nil {
return w.int(0)
}
if *v {
return w.int(1)
}
return w.int(0)
case time.Time:
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case *time.Time:
if v == nil {
v = &time.Time{}
}
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case time.Duration:
return w.int(v.Nanoseconds())
case *time.Duration:
if v == nil {
return w.int(0)
}
return w.int(v.Nanoseconds())
case encoding.BinaryMarshaler:
b, err := v.MarshalBinary()
if err != nil {
Expand Down
Loading
Loading