Skip to content

Commit 679b569

Browse files
committed
Add ping and pong received callbacks
This change adds two optional callbacks to both `DialOptions` and `AcceptOptions`. These callbacks are invoked synchronously when a ping or pong frame is received, allowing advanced users to log or inspect payloads for metrics or debugging. If the callback needs to perform more complex work or reuse the payload outside the callback, it is recommended to clone the byte slice and/or perform processing in a separate goroutine. Tests confirm that the ping/pong callbacks are invoked as expected. References #246
1 parent 11bda98 commit 679b569

File tree

5 files changed

+81
-5
lines changed

5 files changed

+81
-5
lines changed

Diff for: accept.go

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package websocket
55

66
import (
77
"bytes"
8+
"context"
89
"crypto/sha1"
910
"encoding/base64"
1011
"errors"
@@ -62,6 +63,18 @@ type AcceptOptions struct {
6263
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
6364
// for CompressionContextTakeover.
6465
CompressionThreshold int
66+
67+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
68+
//
69+
// To avoid blocking, process the callback asynchronously using a goroutine.
70+
// If you need to reuse the payload outside the callback, clone the byte slice.
71+
// Any modifications to the payload within the callback will be sent in the subsequent pong frame.
72+
OnPingReceived func(context.Context, []byte)
73+
74+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
75+
//
76+
// To avoid blocking, process the callback asynchronously using a goroutine.
77+
OnPongReceived func(context.Context, []byte)
6578
}
6679

6780
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
@@ -156,6 +169,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
156169
client: false,
157170
copts: copts,
158171
flateThreshold: opts.CompressionThreshold,
172+
onPingReceived: opts.OnPingReceived,
173+
onPongReceived: opts.OnPongReceived,
159174

160175
br: brw.Reader,
161176
bw: brw.Writer,

Diff for: conn.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ type Conn struct {
8383
closeMu sync.Mutex // Protects following.
8484
closed chan struct{}
8585

86-
pingCounter atomic.Int64
87-
activePingsMu sync.Mutex
88-
activePings map[string]chan<- struct{}
86+
pingCounter atomic.Int64
87+
activePingsMu sync.Mutex
88+
activePings map[string]chan<- struct{}
89+
onPingReceived func(context.Context, []byte)
90+
onPongReceived func(context.Context, []byte)
8991
}
9092

9193
type connConfig struct {
@@ -94,6 +96,8 @@ type connConfig struct {
9496
client bool
9597
copts *compressionOptions
9698
flateThreshold int
99+
onPingReceived func(context.Context, []byte)
100+
onPongReceived func(context.Context, []byte)
97101

98102
br *bufio.Reader
99103
bw *bufio.Writer
@@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn {
114118
writeTimeout: make(chan context.Context),
115119
timeoutLoopDone: make(chan struct{}),
116120

117-
closed: make(chan struct{}),
118-
activePings: make(map[string]chan<- struct{}),
121+
closed: make(chan struct{}),
122+
activePings: make(map[string]chan<- struct{}),
123+
onPingReceived: cfg.onPingReceived,
124+
onPongReceived: cfg.onPongReceived,
119125
}
120126

121127
c.readMu = newMu(c)

Diff for: conn_test.go

+35
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,41 @@ func TestConn(t *testing.T) {
9797
assert.Contains(t, err, "failed to wait for pong")
9898
})
9999

100+
t.Run("pingPongReceived", func(t *testing.T) {
101+
var pingReceived1, pongReceived1 bool
102+
var pingReceived2, pongReceived2 bool
103+
tt, c1, c2 := newConnTest(t,
104+
&websocket.DialOptions{
105+
OnPingReceived: func(ctx context.Context, payload []byte) {
106+
pingReceived1 = true
107+
},
108+
OnPongReceived: func(ctx context.Context, payload []byte) {
109+
pongReceived1 = true
110+
},
111+
}, &websocket.AcceptOptions{
112+
OnPingReceived: func(ctx context.Context, payload []byte) {
113+
pingReceived2 = true
114+
},
115+
OnPongReceived: func(ctx context.Context, payload []byte) {
116+
pongReceived2 = true
117+
},
118+
},
119+
)
120+
121+
c1.CloseRead(tt.ctx)
122+
c2.CloseRead(tt.ctx)
123+
124+
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
125+
defer cancel()
126+
127+
err := c1.Ping(ctx)
128+
assert.Success(t, err)
129+
130+
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
131+
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
132+
assert.Equal(t, "ping and pong received", true, pingReceived1 && pongReceived2 || pingReceived2 && pongReceived1)
133+
})
134+
100135
t.Run("concurrentWrite", func(t *testing.T) {
101136
tt, c1, c2 := newConnTest(t, nil, nil)
102137

Diff for: dial.go

+14
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ type DialOptions struct {
4848
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
4949
// for CompressionContextTakeover.
5050
CompressionThreshold int
51+
52+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
53+
//
54+
// To avoid blocking, process the callback asynchronously using a goroutine.
55+
// If you need to reuse the payload outside the callback, clone the byte slice.
56+
// Any modifications to the payload within the callback will be sent in the subsequent pong frame.
57+
OnPingReceived func(context.Context, []byte)
58+
59+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
60+
//
61+
// To avoid blocking, process the callback asynchronously using a goroutine.
62+
OnPongReceived func(context.Context, []byte)
5163
}
5264

5365
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
@@ -163,6 +175,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
163175
client: true,
164176
copts: copts,
165177
flateThreshold: opts.CompressionThreshold,
178+
onPingReceived: opts.OnPingReceived,
179+
onPongReceived: opts.OnPongReceived,
166180
br: getBufioReader(rwc),
167181
bw: getBufioWriter(rwc),
168182
}), resp, nil

Diff for: read.go

+6
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,14 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
312312

313313
switch h.opcode {
314314
case opPing:
315+
if c.onPingReceived != nil {
316+
c.onPingReceived(ctx, b)
317+
}
315318
return c.writeControl(ctx, opPong, b)
316319
case opPong:
320+
if c.onPongReceived != nil {
321+
c.onPongReceived(ctx, b)
322+
}
317323
c.activePingsMu.Lock()
318324
pong, ok := c.activePings[string(b)]
319325
c.activePingsMu.Unlock()

0 commit comments

Comments
 (0)