@@ -98,82 +98,106 @@ func CloseStatus(err error) StatusCode {
98
98
//
99
99
// Close will unblock all goroutines interacting with the connection once
100
100
// complete.
101
- func (c * Conn ) Close (code StatusCode , reason string ) error {
102
- defer c .wg .Wait ()
103
- return c .closeHandshake (code , reason )
101
+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
102
+ defer errd .Wrap (& err , "failed to close WebSocket" )
103
+
104
+ if ! c .casClosing () {
105
+ err = c .waitGoroutines ()
106
+ if err != nil {
107
+ return err
108
+ }
109
+ return net .ErrClosed
110
+ }
111
+ defer func () {
112
+ if errors .Is (err , net .ErrClosed ) {
113
+ err = nil
114
+ }
115
+ }()
116
+
117
+ err = c .closeHandshake (code , reason )
118
+
119
+ err2 := c .close ()
120
+ if err == nil && err2 != nil {
121
+ err = err2
122
+ }
123
+
124
+ err2 = c .waitGoroutines ()
125
+ if err == nil && err2 != nil {
126
+ err = err2
127
+ }
128
+
129
+ return err
104
130
}
105
131
106
132
// CloseNow closes the WebSocket connection without attempting a close handshake.
107
133
// Use when you do not want the overhead of the close handshake.
108
134
func (c * Conn ) CloseNow () (err error ) {
109
- defer c .wg .Wait ()
110
135
defer errd .Wrap (& err , "failed to close WebSocket" )
111
136
112
- if c .isClosed () {
137
+ if ! c .casClosing () {
138
+ err = c .waitGoroutines ()
139
+ if err != nil {
140
+ return err
141
+ }
113
142
return net .ErrClosed
114
143
}
144
+ defer func () {
145
+ if errors .Is (err , net .ErrClosed ) {
146
+ err = nil
147
+ }
148
+ }()
115
149
116
- c .close (nil )
117
- c .closeMu .Lock ()
118
- defer c .closeMu .Unlock ()
119
- return c .closeErr
120
- }
121
-
122
- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
123
- defer errd .Wrap (& err , "failed to close WebSocket" )
124
-
125
- writeErr := c .writeClose (code , reason )
126
- closeHandshakeErr := c .waitCloseHandshake ()
150
+ err = c .close ()
127
151
128
- if writeErr != nil {
129
- return writeErr
152
+ err2 := c .waitGoroutines ()
153
+ if err == nil && err2 != nil {
154
+ err = err2
130
155
}
156
+ return err
157
+ }
131
158
132
- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
133
- return closeHandshakeErr
159
+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
160
+ err := c .writeClose (code , reason )
161
+ if err != nil {
162
+ return err
134
163
}
135
164
165
+ err = c .waitCloseHandshake ()
166
+ if CloseStatus (err ) != code {
167
+ return err
168
+ }
136
169
return nil
137
170
}
138
171
139
172
func (c * Conn ) writeClose (code StatusCode , reason string ) error {
140
- c .closeMu .Lock ()
141
- wroteClose := c .wroteClose
142
- c .wroteClose = true
143
- c .closeMu .Unlock ()
144
- if wroteClose {
145
- return net .ErrClosed
146
- }
147
-
148
173
ce := CloseError {
149
174
Code : code ,
150
175
Reason : reason ,
151
176
}
152
177
153
178
var p []byte
154
- var marshalErr error
179
+ var err error
155
180
if ce .Code != StatusNoStatusRcvd {
156
- p , marshalErr = ce .bytes ()
157
- }
158
-
159
- writeErr := c .writeControl (context .Background (), opClose , p )
160
- if CloseStatus (writeErr ) != - 1 {
161
- // Not a real error if it's due to a close frame being received.
162
- writeErr = nil
181
+ p , err = ce .bytes ()
182
+ if err != nil {
183
+ return err
184
+ }
163
185
}
164
186
165
- // We do this after in case there was an error writing the close frame.
166
- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
187
+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
188
+ defer cancel ( )
167
189
168
- if marshalErr != nil {
169
- return marshalErr
190
+ err = c .writeControl (ctx , opClose , p )
191
+ // If the connection closed as we're writing we ignore the error as we might
192
+ // have written the close frame, the peer responded and then someone else read it
193
+ // and closed the connection.
194
+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
195
+ return err
170
196
}
171
- return writeErr
197
+ return nil
172
198
}
173
199
174
200
func (c * Conn ) waitCloseHandshake () error {
175
- defer c .close (nil )
176
-
177
201
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
178
202
defer cancel ()
179
203
@@ -209,6 +233,36 @@ func (c *Conn) waitCloseHandshake() error {
209
233
}
210
234
}
211
235
236
+ func (c * Conn ) waitGoroutines () error {
237
+ t := time .NewTimer (time .Second * 15 )
238
+ defer t .Stop ()
239
+
240
+ select {
241
+ case <- c .timeoutLoopDone :
242
+ case <- t .C :
243
+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
244
+ }
245
+
246
+ c .closeReadMu .Lock ()
247
+ ctx := c .closeReadCtx
248
+ c .closeReadMu .Unlock ()
249
+ if ctx != nil {
250
+ select {
251
+ case <- ctx .Done ():
252
+ case <- t .C :
253
+ return errors .New ("failed to wait for close read goroutine to exit" )
254
+ }
255
+ }
256
+
257
+ select {
258
+ case <- c .closed :
259
+ case <- t .C :
260
+ return errors .New ("failed to wait for connection to be closed" )
261
+ }
262
+
263
+ return nil
264
+ }
265
+
212
266
func parseClosePayload (p []byte ) (CloseError , error ) {
213
267
if len (p ) == 0 {
214
268
return CloseError {
@@ -279,16 +333,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
279
333
return buf , nil
280
334
}
281
335
282
- func (c * Conn ) setCloseErr ( err error ) {
336
+ func (c * Conn ) casClosing () bool {
283
337
c .closeMu .Lock ()
284
- c .setCloseErrLocked (err )
285
- c .closeMu .Unlock ()
286
- }
287
-
288
- func (c * Conn ) setCloseErrLocked (err error ) {
289
- if c .closeErr == nil && err != nil {
290
- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
338
+ defer c .closeMu .Unlock ()
339
+ if ! c .closing {
340
+ c .closing = true
341
+ return true
291
342
}
343
+ return false
292
344
}
293
345
294
346
func (c * Conn ) isClosed () bool {
0 commit comments