@@ -12,9 +12,9 @@ import (
12
12
"golang.org/x/xerrors"
13
13
)
14
14
15
- type controlFrame struct {
16
- header header
17
- data []byte
15
+ type control struct {
16
+ opcode opcode
17
+ payload []byte
18
18
}
19
19
20
20
// Conn represents a WebSocket connection.
@@ -35,8 +35,10 @@ type Conn struct {
35
35
// Writers should send on write to begin sending
36
36
// a message and then follow that up with some data
37
37
// on writeBytes.
38
- write chan opcode
38
+ write chan DataType
39
+ control chan control
39
40
writeBytes chan []byte
41
+ writeDone chan struct {}
40
42
41
43
// Readers should receive on read to begin reading a message.
42
44
// Then send a byte slice to readBytes to read into it.
@@ -81,7 +83,9 @@ func (c *Conn) Subprotocol() string {
81
83
82
84
func (c * Conn ) init () {
83
85
c .closed = make (chan struct {})
84
- c .write = make (chan opcode )
86
+ c .write = make (chan DataType )
87
+ c .control = make (chan control )
88
+ c .writeDone = make (chan struct {})
85
89
c .read = make (chan opcode )
86
90
c .readDone = make (chan int )
87
91
c .readBytes = make (chan []byte )
@@ -94,67 +98,98 @@ func (c *Conn) init() {
94
98
go c .readLoop ()
95
99
}
96
100
101
+ func (c * Conn ) writeFrame (h header , p []byte ) {
102
+ b2 := marshalHeader (h )
103
+ _ , err := c .bw .Write (b2 )
104
+ if err != nil {
105
+ c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
106
+ return
107
+ }
108
+
109
+ _ , err = c .bw .Write (p )
110
+ if err != nil {
111
+ c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
112
+ return
113
+ }
114
+
115
+ if h .opcode .controlOp () {
116
+ err := c .bw .Flush ()
117
+ if err != nil {
118
+ c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
119
+ return
120
+ }
121
+ }
122
+ }
123
+
97
124
func (c * Conn ) writeLoop () {
98
125
messageLoop:
99
126
for {
100
127
c .writeBytes = make (chan []byte )
101
- var opcode opcode
128
+
129
+ var dataType DataType
102
130
select {
103
131
case <- c .closed :
104
132
return
105
- case opcode = <- c .write :
133
+ case dataType = <- c .write :
134
+ case control := <- c .control :
135
+ h := header {
136
+ fin : true ,
137
+ opcode : control .opcode ,
138
+ payloadLength : int64 (len (control .payload )),
139
+ masked : c .client ,
140
+ }
141
+ c .writeFrame (h , control .payload )
142
+ c .writeDone <- struct {}{}
143
+ continue
106
144
}
107
145
108
146
var firstSent bool
109
147
for {
110
148
select {
111
149
case <- c .closed :
112
150
return
151
+ case control := <- c .control :
152
+ h := header {
153
+ fin : true ,
154
+ opcode : control .opcode ,
155
+ payloadLength : int64 (len (control .payload )),
156
+ masked : c .client ,
157
+ }
158
+ c .writeFrame (h , control .payload )
159
+ c .writeDone <- struct {}{}
160
+ continue
113
161
case b , ok := <- c .writeBytes :
114
- if ! firstSent || ! opcode .controlOp () {
115
- h := header {
116
- fin : opcode .controlOp () || ! ok ,
117
- opcode : opcode ,
118
- payloadLength : int64 (len (b )),
119
- masked : c .client ,
120
- }
162
+ h := header {
163
+ fin : ! ok ,
164
+ opcode : opcode (dataType ),
165
+ payloadLength : int64 (len (b )),
166
+ masked : c .client ,
167
+ }
121
168
122
- if firstSent {
123
- h .opcode = opContinuation
124
- }
125
- firstSent = true
169
+ if firstSent {
170
+ h .opcode = opContinuation
171
+ }
172
+ firstSent = true
126
173
127
- b2 := marshalHeader (h )
128
- _ , err := c .bw .Write (b2 )
129
- if err != nil {
130
- c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
131
- return
132
- }
174
+ c .writeFrame (h , b )
133
175
134
- _ , err = c .bw .Write (b )
176
+ if ! ok {
177
+ err := c .bw .Flush ()
135
178
if err != nil {
136
179
c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
137
180
return
138
181
}
139
182
}
140
183
141
- if ok {
142
- select {
143
- case <- c .closed :
144
- return
145
- case c .writeBytes <- nil :
146
- }
147
- } else {
148
- err := c .bw .Flush ()
149
- if err != nil {
150
- c .close (xerrors .Errorf ("failed to write to connection: %v" , err ))
151
- return
152
- }
153
- if opcode == opClose {
154
- c .close (nil )
155
- return
184
+ select {
185
+ case <- c .closed :
186
+ return
187
+ case c .writeDone <- struct {}{}:
188
+ if ok {
189
+ continue
190
+ } else {
191
+ continue messageLoop
156
192
}
157
- continue messageLoop
158
193
}
159
194
}
160
195
}
@@ -167,6 +202,11 @@ func (c *Conn) handleControl(h header) {
167
202
return
168
203
}
169
204
205
+ if ! h .fin {
206
+ c .Close (StatusProtocolError , "control frame cannot be fragmented" )
207
+ return
208
+ }
209
+
170
210
b := make ([]byte , h .payloadLength )
171
211
_ , err := io .ReadFull (c .br , b )
172
212
if err != nil {
@@ -183,12 +223,20 @@ func (c *Conn) handleControl(h header) {
183
223
c .writePong (b )
184
224
case opPong :
185
225
case opClose :
186
- code , reason , err := parseClosePayload (b )
187
- if err != nil {
188
- c .close (xerrors .Errorf ("read invalid close payload: %v" , err ))
189
- return
226
+ if len (b ) > 0 {
227
+ code , reason , err := parseClosePayload (b )
228
+ if err != nil {
229
+ c .close (xerrors .Errorf ("read invalid close payload: %v" , err ))
230
+ return
231
+ }
232
+ c .Close (code , reason )
233
+ } else {
234
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
235
+ defer cancel ()
236
+
237
+ c .writeControl (ctx , opClose , nil )
238
+ c .close (nil )
190
239
}
191
- c .Close (code , reason )
192
240
default :
193
241
panic (fmt .Sprintf ("websocket: unexpected control opcode: %#v" , h ))
194
242
}
@@ -208,33 +256,38 @@ func (c *Conn) readLoop() {
208
256
return
209
257
}
210
258
211
- // TODO this is fucked, as if they are reading a frame as they are writing, then we can't send ping/close so we'll just get stuck for 5s.
212
- switch h .opcode {
213
- case opClose , opPing , opPong :
259
+ if h .opcode .controlOp () {
214
260
c .handleControl (h )
215
261
continue
216
262
}
217
263
218
264
switch h .opcode {
219
265
case opBinary , opText :
266
+ if ! indata {
267
+ select {
268
+ case <- c .closed :
269
+ return
270
+ case c .read <- h .opcode :
271
+ }
272
+ indata = true
273
+ } else {
274
+ c .Close (StatusProtocolError , "cannot send data frame when previous frame is not finished" )
275
+ return
276
+ }
277
+ case opContinuation :
278
+ if ! indata {
279
+ c .Close (StatusProtocolError , "continuation frame not after data or text frame" )
280
+ return
281
+ }
220
282
default :
221
283
c .close (xerrors .Errorf ("unexpected opcode in header: %#v" , h ))
222
284
return
223
285
}
224
286
225
- if ! indata {
226
- select {
227
- case <- c .closed :
228
- return
229
- case c .read <- h .opcode :
230
- }
231
- } else {
232
- indata = true
233
- }
234
-
235
- var maskPos int
287
+ maskPos := 0
236
288
left := h .payloadLength
237
- for left > 0 {
289
+ firstRead := false
290
+ for left > 0 || ! firstRead {
238
291
select {
239
292
case <- c .closed :
240
293
return
@@ -258,6 +311,7 @@ func (c *Conn) readLoop() {
258
311
case <- c .closed :
259
312
return
260
313
case c .readDone <- len (b ):
314
+ firstRead = true
261
315
}
262
316
}
263
317
}
@@ -277,13 +331,7 @@ func (c *Conn) writePong(p []byte) error {
277
331
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
278
332
defer cancel ()
279
333
280
- w := c .messageWriter (opPong )
281
- w .SetContext (ctx )
282
- _ , err := w .Write (p )
283
- if err != nil {
284
- return err
285
- }
286
- err = w .Close ()
334
+ err := c .writeControl (ctx , opPong , p )
287
335
return err
288
336
}
289
337
@@ -292,14 +340,10 @@ func (c *Conn) writePong(p []byte) error {
292
340
// Ensure you close the MessageWriter once you have written to entire message.
293
341
// Concurrent calls to MessageWriter are ok.
294
342
func (c * Conn ) MessageWriter (dataType DataType ) * MessageWriter {
295
- return c .messageWriter (opcode (dataType ))
296
- }
297
-
298
- func (c * Conn ) messageWriter (opcode opcode ) * MessageWriter {
299
343
return & MessageWriter {
300
- c : c ,
301
- ctx : context .Background (),
302
- opcode : opcode ,
344
+ c : c ,
345
+ ctx : context .Background (),
346
+ datatype : dataType ,
303
347
}
304
348
}
305
349
@@ -337,48 +381,46 @@ func (c *Conn) Close(code StatusCode, reason string) error {
337
381
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
338
382
defer cancel ()
339
383
340
- select {
341
- case <- c .closed :
342
- return c .getCloseErr ()
343
- case c .write <- opClose :
344
- case <- ctx .Done ():
345
- c .close (xerrors .New ("force closed: close frame write timed out" ))
346
- return c .getCloseErr ()
384
+ err = c .writeControl (ctx , opClose , p )
385
+ if err != nil {
386
+ return err
387
+ }
388
+
389
+ c .close (nil )
390
+
391
+ if err != nil {
392
+ return err
347
393
}
394
+ return c .closeErr
395
+ }
348
396
397
+ func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
349
398
select {
350
399
case <- c .closed :
351
400
return c .getCloseErr ()
352
- case c .writeBytes <- p :
353
- select {
354
- case <- c .closed :
355
- return c .getCloseErr ()
356
- case <- c .writeBytes :
357
- close (c .writeBytes )
358
- case <- ctx .Done ():
359
- return ctx .Err ()
360
- }
401
+ case c .control <- control {
402
+ opcode : opcode ,
403
+ payload : p ,
404
+ }:
361
405
case <- ctx .Done ():
362
406
c .close (xerrors .New ("force closed: close frame write timed out" ))
363
407
return c .getCloseErr ()
364
408
}
365
409
366
410
select {
367
411
case <- c .closed :
368
- if err != nil {
369
- return err
370
- }
371
- return c .closeErr
372
- case <- ctx .Done ():
373
- c .close (xerrors .New ("force closed: close frame write timed out" ))
374
412
return c .getCloseErr ()
413
+ case <- c .writeDone :
414
+ return nil
415
+ case <- ctx .Done ():
416
+ return ctx .Err ()
375
417
}
376
418
}
377
419
378
420
// MessageWriter enables writing to a WebSocket connection.
379
421
// Ensure you close the MessageWriter once you have written to entire message.
380
422
type MessageWriter struct {
381
- opcode opcode
423
+ datatype DataType
382
424
ctx context.Context
383
425
c * Conn
384
426
acquiredLock bool
@@ -396,7 +438,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
396
438
select {
397
439
case <- w .c .closed :
398
440
return 0 , w .c .getCloseErr ()
399
- case w .c .write <- w .opcode :
441
+ case w .c .write <- w .datatype :
400
442
w .acquiredLock = true
401
443
case <- w .ctx .Done ():
402
444
return 0 , w .ctx .Err ()
@@ -410,7 +452,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
410
452
select {
411
453
case <- w .c .closed :
412
454
return 0 , w .c .getCloseErr ()
413
- case <- w .c .writeBytes :
455
+ case <- w .c .writeDone :
414
456
return len (p ), nil
415
457
case <- w .ctx .Done ():
416
458
return 0 , w .ctx .Err ()
@@ -432,13 +474,14 @@ func (w *MessageWriter) Close() error {
432
474
select {
433
475
case <- w .c .closed :
434
476
return w .c .getCloseErr ()
435
- case w .c .write <- w .opcode :
477
+ case w .c .write <- w .datatype :
436
478
w .acquiredLock = true
437
479
case <- w .ctx .Done ():
438
480
return w .ctx .Err ()
439
481
}
440
482
}
441
483
close (w .c .writeBytes )
484
+ <- w .c .writeDone
442
485
return nil
443
486
}
444
487
0 commit comments