Skip to content

Commit 231f9ae

Browse files
committed
close.go: Rewrite how the library handles closing
Far simpler now. Sorry this took a while. Closes coder#427 Closes coder#429 Closes coder#434 Closes coder#436 Closes coder#437
1 parent 1b5ffe9 commit 231f9ae

File tree

6 files changed

+150
-136
lines changed

6 files changed

+150
-136
lines changed

Diff for: close.go

+103-51
Original file line numberDiff line numberDiff line change
@@ -98,82 +98,106 @@ func CloseStatus(err error) StatusCode {
9898
//
9999
// Close will unblock all goroutines interacting with the connection once
100100
// complete.
101-
func (c *Conn) Close(code StatusCode, reason string) error {
102-
defer c.wg.Wait()
103-
return c.closeHandshake(code, reason)
101+
func (c *Conn) Close(code StatusCode, reason string) (err error) {
102+
defer errd.Wrap(&err, "failed to close WebSocket")
103+
104+
if !c.casClosing() {
105+
err = c.waitGoroutines()
106+
if err != nil {
107+
return err
108+
}
109+
return net.ErrClosed
110+
}
111+
defer func() {
112+
if errors.Is(err, net.ErrClosed) {
113+
err = nil
114+
}
115+
}()
116+
117+
err = c.closeHandshake(code, reason)
118+
119+
err2 := c.close()
120+
if err == nil && err2 != nil {
121+
err = err2
122+
}
123+
124+
err2 = c.waitGoroutines()
125+
if err == nil && err2 != nil {
126+
err = err2
127+
}
128+
129+
return err
104130
}
105131

106132
// CloseNow closes the WebSocket connection without attempting a close handshake.
107133
// Use when you do not want the overhead of the close handshake.
108134
func (c *Conn) CloseNow() (err error) {
109-
defer c.wg.Wait()
110135
defer errd.Wrap(&err, "failed to close WebSocket")
111136

112-
if c.isClosed() {
137+
if !c.casClosing() {
138+
err = c.waitGoroutines()
139+
if err != nil {
140+
return err
141+
}
113142
return net.ErrClosed
114143
}
144+
defer func() {
145+
if errors.Is(err, net.ErrClosed) {
146+
err = nil
147+
}
148+
}()
115149

116-
c.close(nil)
117-
c.closeMu.Lock()
118-
defer c.closeMu.Unlock()
119-
return c.closeErr
120-
}
121-
122-
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
123-
defer errd.Wrap(&err, "failed to close WebSocket")
124-
125-
writeErr := c.writeClose(code, reason)
126-
closeHandshakeErr := c.waitCloseHandshake()
150+
err = c.close()
127151

128-
if writeErr != nil {
129-
return writeErr
152+
err2 := c.waitGoroutines()
153+
if err == nil && err2 != nil {
154+
err = err2
130155
}
156+
return err
157+
}
131158

132-
if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
133-
return closeHandshakeErr
159+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
160+
err := c.writeClose(code, reason)
161+
if err != nil {
162+
return err
134163
}
135164

165+
err = c.waitCloseHandshake()
166+
if CloseStatus(err) != code {
167+
return err
168+
}
136169
return nil
137170
}
138171

139172
func (c *Conn) writeClose(code StatusCode, reason string) error {
140-
c.closeMu.Lock()
141-
wroteClose := c.wroteClose
142-
c.wroteClose = true
143-
c.closeMu.Unlock()
144-
if wroteClose {
145-
return net.ErrClosed
146-
}
147-
148173
ce := CloseError{
149174
Code: code,
150175
Reason: reason,
151176
}
152177

153178
var p []byte
154-
var marshalErr error
179+
var err error
155180
if ce.Code != StatusNoStatusRcvd {
156-
p, marshalErr = ce.bytes()
157-
}
158-
159-
writeErr := c.writeControl(context.Background(), opClose, p)
160-
if CloseStatus(writeErr) != -1 {
161-
// Not a real error if it's due to a close frame being received.
162-
writeErr = nil
181+
p, err = ce.bytes()
182+
if err != nil {
183+
return err
184+
}
163185
}
164186

165-
// We do this after in case there was an error writing the close frame.
166-
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
187+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
188+
defer cancel()
167189

168-
if marshalErr != nil {
169-
return marshalErr
190+
err = c.writeControl(ctx, opClose, p)
191+
// If the connection closed as we're writing we ignore the error as we might
192+
// have written the close frame, the peer responded and then someone else read it
193+
// and closed the connection.
194+
if err != nil && !errors.Is(err, net.ErrClosed) {
195+
return err
170196
}
171-
return writeErr
197+
return nil
172198
}
173199

174200
func (c *Conn) waitCloseHandshake() error {
175-
defer c.close(nil)
176-
177201
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
178202
defer cancel()
179203

@@ -209,6 +233,36 @@ func (c *Conn) waitCloseHandshake() error {
209233
}
210234
}
211235

236+
func (c *Conn) waitGoroutines() error {
237+
t := time.NewTimer(time.Second * 15)
238+
defer t.Stop()
239+
240+
select {
241+
case <-c.timeoutLoopDone:
242+
case <-t.C:
243+
return errors.New("failed to wait for timeoutLoop goroutine to exit")
244+
}
245+
246+
c.closeReadMu.Lock()
247+
ctx := c.closeReadCtx
248+
c.closeReadMu.Unlock()
249+
if ctx != nil {
250+
select {
251+
case <-ctx.Done():
252+
case <-t.C:
253+
return errors.New("failed to wait for close read goroutine to exit")
254+
}
255+
}
256+
257+
select {
258+
case <-c.closed:
259+
case <-t.C:
260+
return errors.New("failed to wait for connection to be closed")
261+
}
262+
263+
return nil
264+
}
265+
212266
func parseClosePayload(p []byte) (CloseError, error) {
213267
if len(p) == 0 {
214268
return CloseError{
@@ -279,16 +333,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
279333
return buf, nil
280334
}
281335

282-
func (c *Conn) setCloseErr(err error) {
336+
func (c *Conn) casClosing() bool {
283337
c.closeMu.Lock()
284-
c.setCloseErrLocked(err)
285-
c.closeMu.Unlock()
286-
}
287-
288-
func (c *Conn) setCloseErrLocked(err error) {
289-
if c.closeErr == nil && err != nil {
290-
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
338+
defer c.closeMu.Unlock()
339+
if !c.closing {
340+
c.closing = true
341+
return true
291342
}
343+
return false
292344
}
293345

294346
func (c *Conn) isClosed() bool {

Diff for: conn.go

+29-45
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package websocket
66
import (
77
"bufio"
88
"context"
9-
"errors"
109
"fmt"
1110
"io"
1211
"net"
@@ -53,8 +52,9 @@ type Conn struct {
5352
br *bufio.Reader
5453
bw *bufio.Writer
5554

56-
readTimeout chan context.Context
57-
writeTimeout chan context.Context
55+
readTimeout chan context.Context
56+
writeTimeout chan context.Context
57+
timeoutLoopDone chan struct{}
5858

5959
// Read state.
6060
readMu *mu
@@ -70,11 +70,12 @@ type Conn struct {
7070
writeHeaderBuf [8]byte
7171
writeHeader header
7272

73-
wg sync.WaitGroup
74-
closed chan struct{}
75-
closeMu sync.Mutex
76-
closeErr error
77-
wroteClose bool
73+
closeReadMu sync.Mutex
74+
closeReadCtx context.Context
75+
76+
closed chan struct{}
77+
closeMu sync.Mutex
78+
closing bool
7879

7980
pingCounter int32
8081
activePingsMu sync.Mutex
@@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn {
103104
br: cfg.br,
104105
bw: cfg.bw,
105106

106-
readTimeout: make(chan context.Context),
107-
writeTimeout: make(chan context.Context),
107+
readTimeout: make(chan context.Context),
108+
writeTimeout: make(chan context.Context),
109+
timeoutLoopDone: make(chan struct{}),
108110

109111
closed: make(chan struct{}),
110112
activePings: make(map[string]chan<- struct{}),
@@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn {
128130
}
129131

130132
runtime.SetFinalizer(c, func(c *Conn) {
131-
c.close(errors.New("connection garbage collected"))
133+
c.close()
132134
})
133135

134-
c.wg.Add(1)
135-
go func() {
136-
defer c.wg.Done()
137-
c.timeoutLoop()
138-
}()
136+
go c.timeoutLoop()
139137

140138
return c
141139
}
@@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string {
146144
return c.subprotocol
147145
}
148146

149-
func (c *Conn) close(err error) {
147+
func (c *Conn) close() error {
150148
c.closeMu.Lock()
151149
defer c.closeMu.Unlock()
152150

153151
if c.isClosed() {
154-
return
155-
}
156-
if err == nil {
157-
err = c.rwc.Close()
152+
return net.ErrClosed
158153
}
159-
c.setCloseErrLocked(err)
160-
161-
close(c.closed)
162154
runtime.SetFinalizer(c, nil)
155+
close(c.closed)
163156

164157
// Have to close after c.closed is closed to ensure any goroutine that wakes up
165158
// from the connection being closed also sees that c.closed is closed and returns
166159
// closeErr.
167-
c.rwc.Close()
168-
169-
c.wg.Add(1)
170-
go func() {
171-
defer c.wg.Done()
172-
c.msgWriter.close()
173-
c.msgReader.close()
174-
}()
160+
err := c.rwc.Close()
161+
// With the close of rwc, these become safe to close.
162+
c.msgWriter.close()
163+
c.msgReader.close()
164+
return err
175165
}
176166

177167
func (c *Conn) timeoutLoop() {
168+
defer close(c.timeoutLoopDone)
169+
178170
readCtx := context.Background()
179171
writeCtx := context.Background()
180172

@@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() {
187179
case readCtx = <-c.readTimeout:
188180

189181
case <-readCtx.Done():
190-
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
191-
c.wg.Add(1)
192-
go func() {
193-
defer c.wg.Done()
194-
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
195-
}()
182+
c.close()
183+
return
196184
case <-writeCtx.Done():
197-
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
185+
c.close()
198186
return
199187
}
200188
}
@@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
243231
case <-c.closed:
244232
return net.ErrClosed
245233
case <-ctx.Done():
246-
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
247-
c.close(err)
248-
return err
234+
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
249235
case <-pong:
250236
return nil
251237
}
@@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error {
281267
case <-m.c.closed:
282268
return net.ErrClosed
283269
case <-ctx.Done():
284-
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
285-
m.c.close(err)
286-
return err
270+
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
287271
case m.ch <- struct{}{}:
288272
// To make sure the connection is certainly alive.
289273
// As it's possible the send on m.ch was selected

Diff for: conn_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ func TestConn(t *testing.T) {
345345

346346
func TestWasm(t *testing.T) {
347347
t.Parallel()
348+
if os.Getenv("CI") == "" {
349+
t.Skip()
350+
}
348351

349352
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350353
err := echoServer(w, r, &websocket.AcceptOptions{

0 commit comments

Comments
 (0)