From 5410adb9dc483e2544dc498e571c15d6df8840f5 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 18 Mar 2025 11:07:14 +0200 Subject: [PATCH 1/7] wip --- auth/auth.go | 39 ++++++++++++++++ options.go | 125 ++++++++++++++++++++++++++++++++++----------------- redis.go | 50 ++++++++++++--------- 3 files changed, 152 insertions(+), 62 deletions(-) create mode 100644 auth/auth.go diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000000..d5834fc516 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,39 @@ +package auth + +type StreamingCredentialsProvider interface { + // Subscribe subscribes to the credentials provider and returns a channel that will receive updates. + // The first response is blocking, then data will be pushed to the channel. + Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error) +} + +type CancelProviderFunc func() error + +type CredentialsListener interface { + OnNext(credentials Credentials) + OnError(err error) +} + +type Credentials interface { + BasicAuth() (username string, password string) + RawCredentials() string +} + +type basicAuth struct { + username string + password string +} + +func (b *basicAuth) RawCredentials() string { + return b.username + ":" + b.password +} + +func (b *basicAuth) BasicAuth() (username string, password string) { + return b.username, b.password +} + +func NewCredentials(username, password string) Credentials { + return &basicAuth{ + username: username, + password: password, + } +} diff --git a/options.go b/options.go index a350a02f9b..3f0661d031 100644 --- a/options.go +++ b/options.go @@ -29,10 +29,13 @@ type Limiter interface { // Options keeps the settings to set up redis connection. type Options struct { - // The network type, either tcp or unix. - // Default is tcp. + + // Network type, either tcp or unix. + // + // default: is tcp. Network string - // host:port address. + + // Addr is the address formated as host:port Addr string // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. @@ -42,21 +45,25 @@ type Options struct { // Network and Addr options. Dialer func(ctx context.Context, network, addr string) (net.Conn, error) - // Hook that is called when new connection is established. + // OnConnect Hook that is called when new connection is established. OnConnect func(ctx context.Context, cn *Conn) error // Protocol 2 or 3. Use the version to negotiate RESP version with redis-server. - // Default is 3. + // + // default: 3. Protocol int - // Use the specified Username to authenticate the current connection + + // Username is used to authenticate the current connection // with one of the connections defined in the ACL list when connecting // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. Username string - // Optional password. Must match the password specified in the - // requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), + + // Password is an optional password. Must match the password specified in the + // `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower), // or the User Password when connecting to a Redis 6.0 instance, or greater, // that is using the Redis ACL system. Password string + // CredentialsProvider allows the username and password to be updated // before reconnecting. It should return the current username and password. CredentialsProvider func() (username string, password string) @@ -67,94 +74,128 @@ type Options struct { // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) - // Database to be selected after connecting to the server. + // DB is the database to be selected after connecting to the server. DB int - // Maximum number of retries before giving up. - // Default is 3 retries; -1 (not 0) disables retries. + // MaxRetries is the maximum number of retries before giving up. + // -1 (not 0) disables retries. + // + // default: 3 retries MaxRetries int - // Minimum backoff between each retry. - // Default is 8 milliseconds; -1 disables backoff. + + // MinRetryBackoff is the minimum backoff between each retry. + // -1 disables backoff. + // + // default: 8 milliseconds MinRetryBackoff time.Duration - // Maximum backoff between each retry. - // Default is 512 milliseconds; -1 disables backoff. + + // MaxRetryBackoff is the maximum backoff between each retry. + // -1 disables backoff. + // default: 512 milliseconds; MaxRetryBackoff time.Duration - // Dial timeout for establishing new connections. - // Default is 5 seconds. + // DialTimeout for establishing new connections. + // + // default: 5 seconds DialTimeout time.Duration - // Timeout for socket reads. If reached, commands will fail + + // ReadTimeout for socket reads. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetReadDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetReadDeadline calls completely. + // + // default: 3 seconds ReadTimeout time.Duration - // Timeout for socket writes. If reached, commands will fail + + // WriteTimeout for socket writes. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetWriteDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetWriteDeadline calls completely. + // + // default: 3 seconds WriteTimeout time.Duration + // ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines. // See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts ContextTimeoutEnabled bool - // Type of connection pool. - // true for FIFO pool, false for LIFO pool. + // PoolFIFO type of connection pool. + // + // - true for FIFO pool + // - false for LIFO pool. + // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. PoolFIFO bool - // Base number of socket connections. + + // PoolSize is the base number of socket connections. // Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. // If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, // you can limit it through MaxActiveConns + // + // default: 10 * runtime.GOMAXPROCS(0) PoolSize int - // Amount of time client waits for connection if all connections + + // PoolTimeout is the amount of time client waits for connection if all connections // are busy before returning an error. - // Default is ReadTimeout + 1 second. + // + // default: ReadTimeout + 1 second PoolTimeout time.Duration - // Minimum number of idle connections which is useful when establishing - // new connection is slow. - // Default is 0. the idle connections are not closed by default. + + // MinIdleConns is the minimum number of idle connections which is useful when establishing + // new connection is slow. The idle connections are not closed by default. + // + // default: 0 MinIdleConns int - // Maximum number of idle connections. - // Default is 0. the idle connections are not closed by default. + + // MaxIdleConns is the maximum number of idle connections. + // The idle connections are not closed by default. + // + // default: 0 MaxIdleConns int - // Maximum number of connections allocated by the pool at a given time. + + // MaxActiveConns is the maximum number of connections allocated by the pool at a given time. // When zero, there is no limit on the number of connections in the pool. + // If the pool is full, the next call to Get() will block until a connection is released. MaxActiveConns int + // ConnMaxIdleTime is the maximum amount of time a connection may be idle. // Should be less than server's timeout. // // Expired connections may be closed lazily before reuse. // If d <= 0, connections are not closed due to a connection's idle time. + // -1 disables idle timeout check. // - // Default is 30 minutes. -1 disables idle timeout check. + // default: 30 minutes ConnMaxIdleTime time.Duration + // ConnMaxLifetime is the maximum amount of time a connection may be reused. // // Expired connections may be closed lazily before reuse. // If <= 0, connections are not closed due to a connection's age. // - // Default is to not close idle connections. + // default: 0 ConnMaxLifetime time.Duration - // TLS Config to use. When set, TLS will be negotiated. + // TLSConfig to use. When set, TLS will be negotiated. TLSConfig *tls.Config // Limiter interface used to implement circuit breaker or rate limiter. Limiter Limiter - // Enables read only queries on slave/follower nodes. + // readOnly enables read only queries on slave/follower nodes. readOnly bool - // Disable set-lib on connect. Default is false. + // DisableIndentity set-lib on connect. Default is false. DisableIndentity bool - // Add suffix to client name. Default is empty. + // IdentitySuffix - add suffix to client name. IdentitySuffix string // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. + // When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult UnstableResp3 bool } diff --git a/redis.go b/redis.go index ec3ff616ac..6116d72073 100644 --- a/redis.go +++ b/redis.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" @@ -282,36 +283,47 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } +func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Credentials) error { + var err error + username, password := credentials.BasicAuth() + if username != "" { + err = cn.AuthACL(ctx, username, password).Err() + } else { + err = cn.Auth(ctx, password).Err() + } + return err +} + func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if cn.Inited { return nil } - cn.Inited = true var err error - username, password := c.opt.Username, c.opt.Password - if c.opt.CredentialsProviderContext != nil { - if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { - return err - } - } else if c.opt.CredentialsProvider != nil { - username, password = c.opt.CredentialsProvider() - } - + cn.Inited = true connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(c.opt, connPool) - var auth bool protocol := c.opt.Protocol // By default, use RESP3 in current version. if protocol < 2 { protocol = 3 } + var authenticated bool + username, password := c.opt.Username, c.opt.Password + if c.opt.CredentialsProviderContext != nil { + if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { + return err + } + } else if c.opt.CredentialsProvider != nil { + username, password = c.opt.CredentialsProvider() + } + // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. if err = conn.Hello(ctx, protocol, username, password, "").Err(); err == nil { - auth = true + authenticated = true } else if !isRedisError(err) { // When the server responds with the RESP protocol and the result is not a normal // execution result of the HELLO command, we consider it to be an indication that @@ -323,15 +335,13 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return err } - _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { - if !auth && password != "" { - if username != "" { - pipe.AuthACL(ctx, username, password) - } else { - pipe.Auth(ctx, password) - } + if !authenticated && password != "" { + err = c.reAuth(ctx, conn, auth.NewCredentials(username, password)) + if err != nil { + return err } - + } + _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) } From df9bfce95440f382296af8cbc4c52346851b4c9c Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 24 Mar 2025 16:29:15 +0200 Subject: [PATCH 2/7] update documentation --- auth/auth.go | 25 ++++++++++++++++++++++--- redis.go | 2 +- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index d5834fc516..7cb1d1fd3d 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,20 +1,36 @@ package auth +// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider. +// It is used to provide credentials for authentication. +// The CredentialsListener is used to receive updates when the credentials change. type StreamingCredentialsProvider interface { - // Subscribe subscribes to the credentials provider and returns a channel that will receive updates. - // The first response is blocking, then data will be pushed to the channel. + // Subscribe subscribes to the credentials provider for updates. + // It returns the current credentials, a cancel function to unsubscribe from the provider, + // and an error if any. Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error) } +// CancelProviderFunc is a function that is used to cancel the subscription to the credentials provider. +// It is used to unsubscribe from the provider when the credentials are no longer needed. type CancelProviderFunc func() error +// CredentialsListener is an interface that defines the methods for a credentials listener. +// It is used to receive updates when the credentials change. +// The OnNext method is called when the credentials change. +// The OnError method is called when an error occurs while requesting the credentials. type CredentialsListener interface { OnNext(credentials Credentials) OnError(err error) } +// Credentials is an interface that defines the methods for credentials. +// It is used to provide the credentials for authentication. type Credentials interface { + // BasicAuth returns the username and password for basic authentication. BasicAuth() (username string, password string) + // RawCredentials returns the raw credentials as a string. + // This can be used to extract the username and password from the raw credentials or + // additional information if present in the token. RawCredentials() string } @@ -23,15 +39,18 @@ type basicAuth struct { password string } +// RawCredentials returns the raw credentials as a string. func (b *basicAuth) RawCredentials() string { return b.username + ":" + b.password } +// BasicAuth returns the username and password for basic authentication. func (b *basicAuth) BasicAuth() (username string, password string) { return b.username, b.password } -func NewCredentials(username, password string) Credentials { +// NewBasicCredentials creates a new Credentials object from the given username and password. +func NewBasicCredentials(username, password string) Credentials { return &basicAuth{ username: username, password: password, diff --git a/redis.go b/redis.go index c2a1805146..2026ff9d6c 100644 --- a/redis.go +++ b/redis.go @@ -336,7 +336,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } if !authenticated && password != "" { - err = c.reAuth(ctx, conn, auth.NewCredentials(username, password)) + err = c.reAuth(ctx, conn, auth.NewBasicCredentials(username, password)) if err != nil { return err } From 140a278bb892766ceac20ad46ffe3a9ec09a218d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 24 Mar 2025 16:39:40 +0200 Subject: [PATCH 3/7] add streamingcredentialsprovider in options --- options.go | 9 +++++++++ osscluster.go | 23 +++++++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/options.go b/options.go index df361f5de6..e98b3247a7 100644 --- a/options.go +++ b/options.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal/pool" ) @@ -74,6 +75,14 @@ type Options struct { // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + // DB is the database to be selected after connecting to the server. DB int diff --git a/osscluster.go b/osscluster.go index b018cc9e46..6dd95577ba 100644 --- a/osscluster.go +++ b/osscluster.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/go-redis/v9/internal/pool" @@ -66,11 +67,12 @@ type ClusterOptions struct { OnConnect func(ctx context.Context, cn *Conn) error - Protocol int - Username string - Password string - CredentialsProvider func() (username string, password string) - CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + Protocol int + Username string + Password string + CredentialsProvider func() (username string, password string) + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + StreamingCredentialsProvider auth.StreamingCredentialsProvider MaxRetries int MinRetryBackoff time.Duration @@ -291,11 +293,12 @@ func (opt *ClusterOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, - CredentialsProvider: opt.CredentialsProvider, - CredentialsProviderContext: opt.CredentialsProviderContext, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, From 7f5d87b038bd5a1370eed037a49c879214f8679e Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 24 Mar 2025 18:10:47 +0200 Subject: [PATCH 4/7] fix: put back option in pool creation --- options.go | 1 + 1 file changed, 1 insertion(+) diff --git a/options.go b/options.go index e98b3247a7..eb35353da1 100644 --- a/options.go +++ b/options.go @@ -582,6 +582,7 @@ func newConnPool( PoolFIFO: opt.PoolFIFO, PoolSize: opt.PoolSize, PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, MinIdleConns: opt.MinIdleConns, MaxIdleConns: opt.MaxIdleConns, MaxActiveConns: opt.MaxActiveConns, From fa59ccef4f31aac1bcf3041b009d554df40d0173 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 24 Mar 2025 18:13:58 +0200 Subject: [PATCH 5/7] add package level comment --- auth/auth.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auth/auth.go b/auth/auth.go index 7cb1d1fd3d..dcfd09eba3 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,3 +1,5 @@ +// Package auth package provides authentication-related interfaces and types. +// It also includes a basic implementation of credentials using username and password. package auth // StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider. From 40a89c56cc5284e4d5be35dffe99f7ea2f1607ee Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 15 Apr 2025 13:32:33 +0300 Subject: [PATCH 6/7] Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested --- auth/auth.go | 1 + auth/reauth_credentials_listener.go | 45 +++++++++++++++++++ internal_test.go | 8 ++-- redis.go | 70 ++++++++++++++++++++++++----- sentinel.go | 2 +- 5 files changed, 111 insertions(+), 15 deletions(-) create mode 100644 auth/reauth_credentials_listener.go diff --git a/auth/auth.go b/auth/auth.go index dcfd09eba3..ae9310e043 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -9,6 +9,7 @@ type StreamingCredentialsProvider interface { // Subscribe subscribes to the credentials provider for updates. // It returns the current credentials, a cancel function to unsubscribe from the provider, // and an error if any. + // TODO(ndyakov): Should we add context to the Subscribe method? Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error) } diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go new file mode 100644 index 0000000000..12eb295669 --- /dev/null +++ b/auth/reauth_credentials_listener.go @@ -0,0 +1,45 @@ +package auth + +// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +type ReAuthCredentialsListener struct { + reAuth func(credentials Credentials) error + onErr func(err error) +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. +func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) { + if c.reAuth != nil { + err := c.reAuth(credentials) + if err != nil { + if c.onErr != nil { + c.onErr(err) + } + } + } +} + +// OnError is called when an error occurs. +// It can be called from both the credentials provider and the reAuth function. +func (c *ReAuthCredentialsListener) OnError(err error) { + if c.onErr != nil { + c.onErr(err) + } +} + +// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener. +// Implements the auth.CredentialsListener interface. +func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener { + return &ReAuthCredentialsListener{ + reAuth: reAuth, + onErr: onErr, + } +} + +// Ensure ReAuthCredentialsListener implements the CredentialsListener interface. +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) diff --git a/internal_test.go b/internal_test.go index a6317196a6..c2cbff70a2 100644 --- a/internal_test.go +++ b/internal_test.go @@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) { }, NewClient: func(opt *Options) *Client { c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) @@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) { } createCounter.increment(opt.Addr) c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) diff --git a/redis.go b/redis.go index 2026ff9d6c..94de3fc74b 100644 --- a/redis.go +++ b/redis.go @@ -283,15 +283,57 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Credentials) error { - var err error - username, password := credentials.BasicAuth() - if username != "" { - err = cn.AuthACL(ctx, username, password).Err() - } else { - err = cn.Auth(ctx, password).Err() +func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener { + return auth.NewReAuthCredentialsListener( + c.reAuthConnection(c.context(ctx), conn), + c.onAuthenticationErr(c.context(ctx), conn), + ) +} + +func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error { + return func(credentials auth.Credentials) error { + var err error + username, password := credentials.BasicAuth() + if username != "" { + err = cn.AuthACL(ctx, username, password).Err() + } else { + err = cn.Auth(ctx, password).Err() + } + return err + } +} +func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) { + return func(err error) { + // since the connection pool of the *Conn will actually return us the underlying pool.Conn, + // we can get it from the *Conn and remove it from the clients pool. + if err != nil { + if isBadConn(err, false, c.opt.Addr) { + poolCn, _ := cn.connPool.Get(ctx) + c.connPool.Remove(ctx, poolCn, err) + } + } + } +} + +func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { + onClose := c.onClose + return func() error { + var firstErr error + err := newOnClose() + // Even if we have an error we would like to execute the onClose hook + // if it exists. We will return the first error that occurred. + // This is to keep error handling consistent with the rest of the code. + if err != nil { + firstErr = err + } + if onClose != nil { + err = onClose() + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr } - return err } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { @@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { var authenticated bool username, password := c.opt.Username, c.opt.Password - if c.opt.CredentialsProviderContext != nil { + if c.opt.StreamingCredentialsProvider != nil { + credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider. + Subscribe(c.newReAuthCredentialsListener(ctx, conn)) + if err != nil { + return err + } + c.onClose = c.wrappedOnClose(cancelCredentialsProvider) + username, password = credentials.BasicAuth() + } else if c.opt.CredentialsProviderContext != nil { if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { return err } @@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } if !authenticated && password != "" { - err = c.reAuth(ctx, conn, auth.NewBasicCredentials(username, password)) + err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password)) if err != nil { return err } diff --git a/sentinel.go b/sentinel.go index a4c9f53c40..5534673535 100644 --- a/sentinel.go +++ b/sentinel.go @@ -257,7 +257,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { connPool = newConnPool(opt, rdb.dialHook) rdb.connPool = connPool - rdb.onClose = failover.Close + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { From d0a8c76d8420c1266263f590cf5e4b6cef5c1c83 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 12:12:53 +0300 Subject: [PATCH 7/7] Change function type name Change CancelProviderFunc to UnsubscribeFunc --- .gitignore | 3 ++- auth/auth.go | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index e9c8f52641..0d99709e34 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ testdata/* redis8tests.sh coverage.txt **/coverage.txt -.vscode \ No newline at end of file +.vscode +tmp/* diff --git a/auth/auth.go b/auth/auth.go index ae9310e043..1f5c802248 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -10,12 +10,12 @@ type StreamingCredentialsProvider interface { // It returns the current credentials, a cancel function to unsubscribe from the provider, // and an error if any. // TODO(ndyakov): Should we add context to the Subscribe method? - Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error) + Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) } -// CancelProviderFunc is a function that is used to cancel the subscription to the credentials provider. +// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider. // It is used to unsubscribe from the provider when the credentials are no longer needed. -type CancelProviderFunc func() error +type UnsubscribeFunc func() error // CredentialsListener is an interface that defines the methods for a credentials listener. // It is used to receive updates when the credentials change.