From 6ad9c3b33b2b14d3ceed26f154e6429ed64738a0 Mon Sep 17 00:00:00 2001 From: monkey92t Date: Thu, 11 May 2023 21:45:19 +0800 Subject: [PATCH] feat(plugin): add conn plugin api Signed-off-by: monkey92t --- conn_plugin.go | 52 ++++++++++++++++++++++++++++++++++++ redis.go | 72 +++++++++++++++++++++++++++++++++++++------------- 2 files changed, 105 insertions(+), 19 deletions(-) create mode 100644 conn_plugin.go diff --git a/conn_plugin.go b/conn_plugin.go new file mode 100644 index 000000000..28ba70d8f --- /dev/null +++ b/conn_plugin.go @@ -0,0 +1,52 @@ +package redis + +import "context" + +type ( + // PreInitConnPlugin plugin executed before connection initialization. At this point, + // the network connection has been established, but Redis authentication has not yet + // taken place. You can perform specific operations before the Redis authentication, + // such as third-party Redis proxy authentication or executing any necessary commands. + // Please note that the `HELLO` command has not been executed yet. If you invoke any Redis + // commands, the default RESP version of the Redis server will be used. + PreInitConnPlugin func(ctx context.Context, conn *Conn) error + + // InitConnPlugin redis connection authentication plugin. go-redis sets a default + // authentication plugin, but if you need to implement a special authentication + // mechanism for your Redis server, you can use this plugin instead of the default one. + // This plugin can only be set once, and if set multiple times, + // only the last set plugin will be executed. + InitConnPlugin func(ctx context.Context, conn *Conn) error + + // PostInitConnPlugin Plugin executed after connection initialization. At this point, + // Redis authentication has been completed, and you can execute commands related to + // the connection status, such as `SELECT DB`, `CLIENT SETNAME`. + PostInitConnPlugin func(ctx context.Context, conn *Conn) error +) + +// --------------------------------------------------------------------------------------- + +type plugin struct { + preInitConnPlugins []PreInitConnPlugin + initConnPlugin InitConnPlugin + postInitConnPlugin []PostInitConnPlugin +} + +// RegistryPreInitConnPlugin register a PreInitConnPlugin plugin, which can be registered +// multiple times. It will be executed in the order of registration. +func (p *plugin) RegistryPreInitConnPlugin(pre PreInitConnPlugin) { + p.preInitConnPlugins = append(p.preInitConnPlugins, pre) +} + +// RegistryInitConnPlugin register an InitConnPlugin plugin, which will override the default +// authentication mechanism of go-redis. If registered multiple times, only the plugin +// registered last will be executed. +func (p *plugin) RegistryInitConnPlugin(init InitConnPlugin) { + p.initConnPlugin = init +} + +// RegistryPostInitConnPlugin register a PostInitConnPlugin plugin, which can be registered +// multiple times. It will be executed in the order of registration. +func (p *plugin) RegistryPostInitConnPlugin(post PostInitConnPlugin) { + p.postInitConnPlugin = append(p.postInitConnPlugin, post) +} diff --git a/redis.go b/redis.go index cae12f8c9..69b410acd 100644 --- a/redis.go +++ b/redis.go @@ -183,6 +183,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e //------------------------------------------------------------------------------ type baseClient struct { + plugin opt *Options connPool pool.Pooler @@ -264,22 +265,14 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if cn.Inited { - return nil - } - cn.Inited = true +func (c *baseClient) authentication(ctx context.Context, conn *Conn) error { + var auth bool username, password := c.opt.Username, c.opt.Password if c.opt.CredentialsProvider != nil { username, password = c.opt.CredentialsProvider() } - connPool := pool.NewSingleConnPool(c.connPool, cn) - conn := newConn(c.opt, connPool) - - var auth bool - // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. if err := conn.Hello(ctx, 3, username, password, "").Err(); err == nil { @@ -295,15 +288,49 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return err } - _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { - if !auth && password != "" { - if username != "" { - pipe.AuthACL(ctx, username, password) - } else { - pipe.Auth(ctx, password) - } + if !auth && password != "" { + var authErr error + if username != "" { + authErr = conn.AuthACL(ctx, username, password).Err() + } else { + authErr = conn.Auth(ctx, password).Err() + } + + if authErr != nil { + return authErr + } + } + + return nil +} + +func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { + if cn.Inited { + return nil + } + + connPool := pool.NewSingleConnPool(c.connPool, cn) + conn := newConn(c.opt, connPool, c.plugin) + + for _, p := range c.plugin.preInitConnPlugins { + if err := p(ctx, conn); err != nil { + return err + } + } + + cn.Inited = true + + if c.plugin.initConnPlugin != nil { + if err := c.plugin.initConnPlugin(ctx, conn); err != nil { + return err + } + } else { + if err := c.authentication(ctx, conn); err != nil { + return err } + } + _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) } @@ -322,6 +349,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return err } + for _, p := range c.plugin.postInitConnPlugin { + if err = p(ctx, conn); err != nil { + return err + } + } + if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn) } @@ -631,7 +664,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client { } func (c *Client) Conn() *Conn { - return newConn(c.opt, pool.NewStickyConnPool(c.connPool)) + return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.baseClient.plugin) } // Do create a Cmd from the args and processes the cmd. @@ -767,11 +800,12 @@ type Conn struct { hooksMixin } -func newConn(opt *Options, connPool pool.Pooler) *Conn { +func newConn(opt *Options, connPool pool.Pooler, plugin plugin) *Conn { c := Conn{ baseClient: baseClient{ opt: opt, connPool: connPool, + plugin: plugin, }, }