From 85e86704f9a7716ccdeb08ec6161b6663e27a7d6 Mon Sep 17 00:00:00 2001 From: igolaizola <11333576+igolaizola@users.noreply.github.com> Date: Tue, 28 Jan 2025 09:26:19 +0100 Subject: [PATCH] Add ping and pong received callbacks This change adds two optional callbacks to both `DialOptions` and `AcceptOptions`. These callbacks are invoked synchronously when a ping or pong frame is received, allowing advanced users to log or inspect payloads for metrics or debugging. If the callback needs to perform more complex work or reuse the payload outside the callback, it is recommended to perform processing in a separate goroutine. The boolean return value of `OnPingReceived` is used to determine if the subsequent pong frame should be sent. If `false` is returned, the pong frame is not sent. Tests confirm that the ping/pong callbacks are invoked as expected. References #246 --- accept.go | 19 +++++++++++++ conn.go | 16 +++++++---- conn_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++ dial.go | 18 ++++++++++++ read.go | 8 ++++++ 5 files changed, 135 insertions(+), 5 deletions(-) diff --git a/accept.go b/accept.go index 774ea285..f45fdd0b 100644 --- a/accept.go +++ b/accept.go @@ -5,6 +5,7 @@ package websocket import ( "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" @@ -62,6 +63,22 @@ type AcceptOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { @@ -156,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con client: false, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: brw.Reader, bw: brw.Writer, diff --git a/conn.go b/conn.go index 76b057dd..42fe89fe 100644 --- a/conn.go +++ b/conn.go @@ -83,9 +83,11 @@ type Conn struct { closeMu sync.Mutex // Protects following. closed chan struct{} - pingCounter atomic.Int64 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) } type connConfig struct { @@ -94,6 +96,8 @@ type connConfig struct { client bool copts *compressionOptions flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) br *bufio.Reader bw *bufio.Writer @@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn { writeTimeout: make(chan context.Context), timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, } c.readMu = newMu(c) diff --git a/conn_test.go b/conn_test.go index 9ed8c7ea..45bb75be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -97,6 +97,85 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingReceivedPongReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Success(t, err) + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2) + assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1)) + }) + + t.Run("pingReceivedPongNotReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1)) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) diff --git a/dial.go b/dial.go index ad61a35d..0b11ecbb 100644 --- a/dial.go +++ b/dial.go @@ -48,6 +48,22 @@ type DialOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( client: true, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil diff --git a/read.go b/read.go index 1267b5b9..2db22435 100644 --- a/read.go +++ b/read.go @@ -312,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + if c.onPingReceived != nil { + if !c.onPingReceived(ctx, b) { + return nil + } + } return c.writeControl(ctx, opPong, b) case opPong: + if c.onPongReceived != nil { + c.onPongReceived(ctx, b) + } c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock()