Skip to content

feat(plugin): add conn plugin api #2590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions conn_plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package redis

import "context"

type (
// PreInitConnPlugin plugin executed before connection initialization. At this point,
// the network connection has been established, but Redis authentication has not yet
// taken place. You can perform specific operations before the Redis authentication,
// such as third-party Redis proxy authentication or executing any necessary commands.
// Please note that the `HELLO` command has not been executed yet. If you invoke any Redis
// commands, the default RESP version of the Redis server will be used.
PreInitConnPlugin func(ctx context.Context, conn *Conn) error

// InitConnPlugin redis connection authentication plugin. go-redis sets a default
// authentication plugin, but if you need to implement a special authentication
// mechanism for your Redis server, you can use this plugin instead of the default one.
// This plugin can only be set once, and if set multiple times,
// only the last set plugin will be executed.
InitConnPlugin func(ctx context.Context, conn *Conn) error

// PostInitConnPlugin Plugin executed after connection initialization. At this point,
// Redis authentication has been completed, and you can execute commands related to
// the connection status, such as `SELECT DB`, `CLIENT SETNAME`.
PostInitConnPlugin func(ctx context.Context, conn *Conn) error
)

// ---------------------------------------------------------------------------------------

type plugin struct {
preInitConnPlugins []PreInitConnPlugin
initConnPlugin InitConnPlugin
postInitConnPlugin []PostInitConnPlugin
}

// RegistryPreInitConnPlugin register a PreInitConnPlugin plugin, which can be registered
// multiple times. It will be executed in the order of registration.
func (p *plugin) RegistryPreInitConnPlugin(pre PreInitConnPlugin) {
p.preInitConnPlugins = append(p.preInitConnPlugins, pre)
}

// RegistryInitConnPlugin register an InitConnPlugin plugin, which will override the default
// authentication mechanism of go-redis. If registered multiple times, only the plugin
// registered last will be executed.
func (p *plugin) RegistryInitConnPlugin(init InitConnPlugin) {
p.initConnPlugin = init
}

// RegistryPostInitConnPlugin register a PostInitConnPlugin plugin, which can be registered
// multiple times. It will be executed in the order of registration.
func (p *plugin) RegistryPostInitConnPlugin(post PostInitConnPlugin) {
p.postInitConnPlugin = append(p.postInitConnPlugin, post)
}
72 changes: 53 additions & 19 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
//------------------------------------------------------------------------------

type baseClient struct {
plugin
opt *Options
connPool pool.Pooler

Expand Down Expand Up @@ -264,22 +265,14 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}

func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
func (c *baseClient) authentication(ctx context.Context, conn *Conn) error {
var auth bool

username, password := c.opt.Username, c.opt.Password
if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider()
}

connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool)

var auth bool

// for redis-server versions that do not support the HELLO command,
// RESP2 will continue to be used.
if err := conn.Hello(ctx, 3, username, password, "").Err(); err == nil {
Expand All @@ -295,15 +288,49 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return err
}

_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
if !auth && password != "" {
if username != "" {
pipe.AuthACL(ctx, username, password)
} else {
pipe.Auth(ctx, password)
}
if !auth && password != "" {
var authErr error
if username != "" {
authErr = conn.AuthACL(ctx, username, password).Err()
} else {
authErr = conn.Auth(ctx, password).Err()
}

if authErr != nil {
return authErr
}
}

return nil
}

func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}

connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, c.plugin)

for _, p := range c.plugin.preInitConnPlugins {
if err := p(ctx, conn); err != nil {
return err
}
}

cn.Inited = true

if c.plugin.initConnPlugin != nil {
if err := c.plugin.initConnPlugin(ctx, conn); err != nil {
return err
}
} else {
if err := c.authentication(ctx, conn); err != nil {
return err
}
}

_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB)
}
Expand All @@ -322,6 +349,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return err
}

for _, p := range c.plugin.postInitConnPlugin {
if err = p(ctx, conn); err != nil {
return err
}
}

if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
Expand Down Expand Up @@ -631,7 +664,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
}

func (c *Client) Conn() *Conn {
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.baseClient.plugin)
}

// Do create a Cmd from the args and processes the cmd.
Expand Down Expand Up @@ -767,11 +800,12 @@ type Conn struct {
hooksMixin
}

func newConn(opt *Options, connPool pool.Pooler) *Conn {
func newConn(opt *Options, connPool pool.Pooler, plugin plugin) *Conn {
c := Conn{
baseClient: baseClient{
opt: opt,
connPool: connPool,
plugin: plugin,
},
}

Expand Down