Skip to content

Commit b9465b2

Browse files
committedApr 7, 2019
Get near full autobahn suite passing
1 parent 11e2521 commit b9465b2

File tree

3 files changed

+180
-110
lines changed

3 files changed

+180
-110
lines changed
 

‎statuscode.go

+30
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"math/bits"
8+
"unicode/utf8"
89

910
"golang.org/x/xerrors"
1011
)
@@ -53,9 +54,38 @@ func parseClosePayload(p []byte) (code StatusCode, reason string, err error) {
5354
code = StatusCode(binary.BigEndian.Uint16(p))
5455
reason = string(p[2:])
5556

57+
if !utf8.ValidString(reason) {
58+
return 0, "", xerrors.Errorf("invalid utf-8: %q", reason)
59+
}
60+
if !isValidReceivedCloseCode(code) {
61+
return 0, "", xerrors.Errorf("invalid code %v", code)
62+
}
63+
5664
return code, reason, nil
5765
}
5866

67+
var validReceivedCloseCodes = map[StatusCode]bool{
68+
// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
69+
StatusNormalClosure: true,
70+
StatusGoingAway: true,
71+
StatusProtocolError: true,
72+
StatusUnsupportedData: true,
73+
StatusNoStatusRcvd: false,
74+
StatusAbnormalClosure: false,
75+
StatusInvalidFramePayloadData: true,
76+
StatusPolicyViolation: true,
77+
StatusMessageTooBig: true,
78+
StatusMandatoryExtension: true,
79+
StatusInternalError: true,
80+
StatusServiceRestart: true,
81+
StatusTryAgainLater: true,
82+
StatusTLSHandshake: false,
83+
}
84+
85+
func isValidReceivedCloseCode(code StatusCode) bool {
86+
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
87+
}
88+
5989
const maxControlFramePayload = 125
6090

6191
func closePayload(code StatusCode, reason string) ([]byte, error) {

‎websocket.go

+144-101
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ import (
1212
"golang.org/x/xerrors"
1313
)
1414

15-
type controlFrame struct {
16-
header header
17-
data []byte
15+
type control struct {
16+
opcode opcode
17+
payload []byte
1818
}
1919

2020
// Conn represents a WebSocket connection.
@@ -35,8 +35,10 @@ type Conn struct {
3535
// Writers should send on write to begin sending
3636
// a message and then follow that up with some data
3737
// on writeBytes.
38-
write chan opcode
38+
write chan DataType
39+
control chan control
3940
writeBytes chan []byte
41+
writeDone chan struct{}
4042

4143
// Readers should receive on read to begin reading a message.
4244
// Then send a byte slice to readBytes to read into it.
@@ -81,7 +83,9 @@ func (c *Conn) Subprotocol() string {
8183

8284
func (c *Conn) init() {
8385
c.closed = make(chan struct{})
84-
c.write = make(chan opcode)
86+
c.write = make(chan DataType)
87+
c.control = make(chan control)
88+
c.writeDone = make(chan struct{})
8589
c.read = make(chan opcode)
8690
c.readDone = make(chan int)
8791
c.readBytes = make(chan []byte)
@@ -94,67 +98,98 @@ func (c *Conn) init() {
9498
go c.readLoop()
9599
}
96100

101+
func (c *Conn) writeFrame(h header, p []byte) {
102+
b2 := marshalHeader(h)
103+
_, err := c.bw.Write(b2)
104+
if err != nil {
105+
c.close(xerrors.Errorf("failed to write to connection: %v", err))
106+
return
107+
}
108+
109+
_, err = c.bw.Write(p)
110+
if err != nil {
111+
c.close(xerrors.Errorf("failed to write to connection: %v", err))
112+
return
113+
}
114+
115+
if h.opcode.controlOp() {
116+
err := c.bw.Flush()
117+
if err != nil {
118+
c.close(xerrors.Errorf("failed to write to connection: %v", err))
119+
return
120+
}
121+
}
122+
}
123+
97124
func (c *Conn) writeLoop() {
98125
messageLoop:
99126
for {
100127
c.writeBytes = make(chan []byte)
101-
var opcode opcode
128+
129+
var dataType DataType
102130
select {
103131
case <-c.closed:
104132
return
105-
case opcode = <-c.write:
133+
case dataType = <-c.write:
134+
case control := <-c.control:
135+
h := header{
136+
fin: true,
137+
opcode: control.opcode,
138+
payloadLength: int64(len(control.payload)),
139+
masked: c.client,
140+
}
141+
c.writeFrame(h, control.payload)
142+
c.writeDone <- struct{}{}
143+
continue
106144
}
107145

108146
var firstSent bool
109147
for {
110148
select {
111149
case <-c.closed:
112150
return
151+
case control := <-c.control:
152+
h := header{
153+
fin: true,
154+
opcode: control.opcode,
155+
payloadLength: int64(len(control.payload)),
156+
masked: c.client,
157+
}
158+
c.writeFrame(h, control.payload)
159+
c.writeDone <- struct{}{}
160+
continue
113161
case b, ok := <-c.writeBytes:
114-
if !firstSent || !opcode.controlOp() {
115-
h := header{
116-
fin: opcode.controlOp() || !ok,
117-
opcode: opcode,
118-
payloadLength: int64(len(b)),
119-
masked: c.client,
120-
}
162+
h := header{
163+
fin: !ok,
164+
opcode: opcode(dataType),
165+
payloadLength: int64(len(b)),
166+
masked: c.client,
167+
}
121168

122-
if firstSent {
123-
h.opcode = opContinuation
124-
}
125-
firstSent = true
169+
if firstSent {
170+
h.opcode = opContinuation
171+
}
172+
firstSent = true
126173

127-
b2 := marshalHeader(h)
128-
_, err := c.bw.Write(b2)
129-
if err != nil {
130-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
131-
return
132-
}
174+
c.writeFrame(h, b)
133175

134-
_, err = c.bw.Write(b)
176+
if !ok {
177+
err := c.bw.Flush()
135178
if err != nil {
136179
c.close(xerrors.Errorf("failed to write to connection: %v", err))
137180
return
138181
}
139182
}
140183

141-
if ok {
142-
select {
143-
case <-c.closed:
144-
return
145-
case c.writeBytes <- nil:
146-
}
147-
} else {
148-
err := c.bw.Flush()
149-
if err != nil {
150-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
151-
return
152-
}
153-
if opcode == opClose {
154-
c.close(nil)
155-
return
184+
select {
185+
case <-c.closed:
186+
return
187+
case c.writeDone <- struct{}{}:
188+
if ok {
189+
continue
190+
} else {
191+
continue messageLoop
156192
}
157-
continue messageLoop
158193
}
159194
}
160195
}
@@ -167,6 +202,11 @@ func (c *Conn) handleControl(h header) {
167202
return
168203
}
169204

205+
if !h.fin {
206+
c.Close(StatusProtocolError, "control frame cannot be fragmented")
207+
return
208+
}
209+
170210
b := make([]byte, h.payloadLength)
171211
_, err := io.ReadFull(c.br, b)
172212
if err != nil {
@@ -183,12 +223,20 @@ func (c *Conn) handleControl(h header) {
183223
c.writePong(b)
184224
case opPong:
185225
case opClose:
186-
code, reason, err := parseClosePayload(b)
187-
if err != nil {
188-
c.close(xerrors.Errorf("read invalid close payload: %v", err))
189-
return
226+
if len(b) > 0 {
227+
code, reason, err := parseClosePayload(b)
228+
if err != nil {
229+
c.close(xerrors.Errorf("read invalid close payload: %v", err))
230+
return
231+
}
232+
c.Close(code, reason)
233+
} else {
234+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
235+
defer cancel()
236+
237+
c.writeControl(ctx, opClose, nil)
238+
c.close(nil)
190239
}
191-
c.Close(code, reason)
192240
default:
193241
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
194242
}
@@ -208,33 +256,38 @@ func (c *Conn) readLoop() {
208256
return
209257
}
210258

211-
// TODO this is fucked, as if they are reading a frame as they are writing, then we can't send ping/close so we'll just get stuck for 5s.
212-
switch h.opcode {
213-
case opClose, opPing, opPong:
259+
if h.opcode.controlOp() {
214260
c.handleControl(h)
215261
continue
216262
}
217263

218264
switch h.opcode {
219265
case opBinary, opText:
266+
if !indata {
267+
select {
268+
case <-c.closed:
269+
return
270+
case c.read <- h.opcode:
271+
}
272+
indata = true
273+
} else {
274+
c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished")
275+
return
276+
}
277+
case opContinuation:
278+
if !indata {
279+
c.Close(StatusProtocolError, "continuation frame not after data or text frame")
280+
return
281+
}
220282
default:
221283
c.close(xerrors.Errorf("unexpected opcode in header: %#v", h))
222284
return
223285
}
224286

225-
if !indata {
226-
select {
227-
case <-c.closed:
228-
return
229-
case c.read <- h.opcode:
230-
}
231-
} else {
232-
indata = true
233-
}
234-
235-
var maskPos int
287+
maskPos := 0
236288
left := h.payloadLength
237-
for left > 0 {
289+
firstRead := false
290+
for left > 0 || !firstRead {
238291
select {
239292
case <-c.closed:
240293
return
@@ -258,6 +311,7 @@ func (c *Conn) readLoop() {
258311
case <-c.closed:
259312
return
260313
case c.readDone <- len(b):
314+
firstRead = true
261315
}
262316
}
263317
}
@@ -277,13 +331,7 @@ func (c *Conn) writePong(p []byte) error {
277331
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
278332
defer cancel()
279333

280-
w := c.messageWriter(opPong)
281-
w.SetContext(ctx)
282-
_, err := w.Write(p)
283-
if err != nil {
284-
return err
285-
}
286-
err = w.Close()
334+
err := c.writeControl(ctx, opPong, p)
287335
return err
288336
}
289337

@@ -292,14 +340,10 @@ func (c *Conn) writePong(p []byte) error {
292340
// Ensure you close the MessageWriter once you have written to entire message.
293341
// Concurrent calls to MessageWriter are ok.
294342
func (c *Conn) MessageWriter(dataType DataType) *MessageWriter {
295-
return c.messageWriter(opcode(dataType))
296-
}
297-
298-
func (c *Conn) messageWriter(opcode opcode) *MessageWriter {
299343
return &MessageWriter{
300-
c: c,
301-
ctx: context.Background(),
302-
opcode: opcode,
344+
c: c,
345+
ctx: context.Background(),
346+
datatype: dataType,
303347
}
304348
}
305349

@@ -337,48 +381,46 @@ func (c *Conn) Close(code StatusCode, reason string) error {
337381
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
338382
defer cancel()
339383

340-
select {
341-
case <-c.closed:
342-
return c.getCloseErr()
343-
case c.write <- opClose:
344-
case <-ctx.Done():
345-
c.close(xerrors.New("force closed: close frame write timed out"))
346-
return c.getCloseErr()
384+
err = c.writeControl(ctx, opClose, p)
385+
if err != nil {
386+
return err
387+
}
388+
389+
c.close(nil)
390+
391+
if err != nil {
392+
return err
347393
}
394+
return c.closeErr
395+
}
348396

397+
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
349398
select {
350399
case <-c.closed:
351400
return c.getCloseErr()
352-
case c.writeBytes <- p:
353-
select {
354-
case <-c.closed:
355-
return c.getCloseErr()
356-
case <-c.writeBytes:
357-
close(c.writeBytes)
358-
case <-ctx.Done():
359-
return ctx.Err()
360-
}
401+
case c.control <- control{
402+
opcode: opcode,
403+
payload: p,
404+
}:
361405
case <-ctx.Done():
362406
c.close(xerrors.New("force closed: close frame write timed out"))
363407
return c.getCloseErr()
364408
}
365409

366410
select {
367411
case <-c.closed:
368-
if err != nil {
369-
return err
370-
}
371-
return c.closeErr
372-
case <-ctx.Done():
373-
c.close(xerrors.New("force closed: close frame write timed out"))
374412
return c.getCloseErr()
413+
case <-c.writeDone:
414+
return nil
415+
case <-ctx.Done():
416+
return ctx.Err()
375417
}
376418
}
377419

378420
// MessageWriter enables writing to a WebSocket connection.
379421
// Ensure you close the MessageWriter once you have written to entire message.
380422
type MessageWriter struct {
381-
opcode opcode
423+
datatype DataType
382424
ctx context.Context
383425
c *Conn
384426
acquiredLock bool
@@ -396,7 +438,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
396438
select {
397439
case <-w.c.closed:
398440
return 0, w.c.getCloseErr()
399-
case w.c.write <- w.opcode:
441+
case w.c.write <- w.datatype:
400442
w.acquiredLock = true
401443
case <-w.ctx.Done():
402444
return 0, w.ctx.Err()
@@ -410,7 +452,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
410452
select {
411453
case <-w.c.closed:
412454
return 0, w.c.getCloseErr()
413-
case <-w.c.writeBytes:
455+
case <-w.c.writeDone:
414456
return len(p), nil
415457
case <-w.ctx.Done():
416458
return 0, w.ctx.Err()
@@ -432,13 +474,14 @@ func (w *MessageWriter) Close() error {
432474
select {
433475
case <-w.c.closed:
434476
return w.c.getCloseErr()
435-
case w.c.write <- w.opcode:
477+
case w.c.write <- w.datatype:
436478
w.acquiredLock = true
437479
case <-w.ctx.Done():
438480
return w.ctx.Err()
439481
}
440482
}
441483
close(w.c.writeBytes)
484+
<-w.c.writeDone
442485
return nil
443486
}
444487

‎websocket_test.go

+6-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package websocket_test
33
import (
44
"context"
55
"encoding/json"
6-
"golang.org/x/time/rate"
76
"io"
87
"io/ioutil"
98
"net/http"
@@ -118,14 +117,11 @@ func TestAutobahn(t *testing.T) {
118117
return err
119118
}
120119

121-
ctx, cancel = context.WithTimeout(ctx, time.Second*10)
122-
defer cancel()
123-
124120
r.SetContext(ctx)
125-
r.Limit(131072)
126121

127122
w := c.MessageWriter(typ)
128123
w.SetContext(ctx)
124+
129125
_, err = io.Copy(w, r)
130126
if err != nil {
131127
return err
@@ -139,8 +135,7 @@ func TestAutobahn(t *testing.T) {
139135
return nil
140136
}
141137

142-
l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
143-
for l.Allow() {
138+
for {
144139
err := echo()
145140
if err != nil {
146141
t.Logf("%v: failed to echo message: %+v", time.Now(), err)
@@ -162,7 +157,7 @@ func TestAutobahn(t *testing.T) {
162157
},
163158
},
164159
"cases": []string{"*"},
165-
"exclude-cases": []interface{}{},
160+
"exclude-cases": []string{"6.*", "12.*", "13.*"},
166161
}
167162
specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json")
168163
if err != nil {
@@ -216,7 +211,9 @@ func TestAutobahn(t *testing.T) {
216211
var failed bool
217212
for _, tests := range indexJSON {
218213
for test, result := range tests {
219-
if result.Behavior != "OK" {
214+
switch result.Behavior {
215+
case "OK", "NON-STRICT", "INFORMATIONAL":
216+
default:
220217
failed = true
221218
t.Errorf("test %v failed", test)
222219
}

0 commit comments

Comments
 (0)
Please sign in to comment.