From 8e3039accaa05c319ebd0ab0b2878b6b5baf2303 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 4 Apr 2019 19:34:38 -0500 Subject: [PATCH 1/4] Set a finalizer on Conn --- websocket.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/websocket.go b/websocket.go index 83ef0bb5..d25f55ed 100644 --- a/websocket.go +++ b/websocket.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "runtime" "sync" "time" @@ -19,7 +20,6 @@ type controlFrame struct { // Conn represents a WebSocket connection. // Pings will always be automatically responded to with pongs, you do not // have to do anything special. -// TODO set finalizer type Conn struct { subprotocol string br *bufio.Reader @@ -59,6 +59,8 @@ func (c *Conn) close(err error) { } c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) + c.closeErr = err cerr := c.closer.Close() @@ -82,6 +84,10 @@ func (c *Conn) init() { c.read = make(chan opcode) c.readBytes = make(chan []byte) + runtime.SetFinalizer(c, func(c *Conn) { + c.Close(StatusInternalError, "websocket: connection ended up being garbage collected") + }) + go c.writeLoop() go c.readLoop() } From d8499350eeff7244143a2d0bf7b902a658227757 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 2/4] Integrate autobahn tests --- .gitignore | 1 + opcode.go | 2 +- opcode_string.go | 10 +- statuscode.go | 2 +- statuscode_string.go | 26 ++-- websocket.go | 284 +++++++++++++++++++++++++------------------ websocket_test.go | 146 +++++++++++++++++++++- 7 files changed, 335 insertions(+), 136 deletions(-) diff --git a/.gitignore b/.gitignore index 4383ca89..70d8e703 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ coverage.html +wstest_reports diff --git a/opcode.go b/opcode.go index de54fdb4..2469e781 100644 --- a/opcode.go +++ b/opcode.go @@ -10,7 +10,7 @@ const ( opText opBinary // 3 - 7 are reserved for further non-control frames. - opClose opcode = 8 + iota + opClose opcode = 8 + iota - 3 opPing opPong // 11-16 are reserved for further control frames. diff --git a/opcode_string.go b/opcode_string.go index e815cdea..740b5e70 100644 --- a/opcode_string.go +++ b/opcode_string.go @@ -11,9 +11,9 @@ func _() { _ = x[opContinuation-0] _ = x[opText-1] _ = x[opBinary-2] - _ = x[opClose-11] - _ = x[opPing-12] - _ = x[opPong-13] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] } const ( @@ -30,8 +30,8 @@ func (i opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] - case 11 <= i && i <= 13: - i -= 11 + case 8 <= i && i <= 10: + i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" diff --git a/statuscode.go b/statuscode.go index 7efbfcd1..790de712 100644 --- a/statuscode.go +++ b/statuscode.go @@ -21,7 +21,7 @@ const ( StatusProtocolError StatusUnsupportedData // 1004 is reserved. - StatusNoStatusRcvd StatusCode = 1005 + iota + StatusNoStatusRcvd StatusCode = 1005 + iota - 4 StatusAbnormalClosure StatusInvalidFramePayloadData StatusPolicyViolation diff --git a/statuscode_string.go b/statuscode_string.go index e1bfc462..fc8cea0d 100644 --- a/statuscode_string.go +++ b/statuscode_string.go @@ -12,17 +12,17 @@ func _() { _ = x[StatusGoingAway-1001] _ = x[StatusProtocolError-1002] _ = x[StatusUnsupportedData-1003] - _ = x[StatusNoStatusRcvd-1009] - _ = x[StatusAbnormalClosure-1010] - _ = x[StatusInvalidFramePayloadData-1011] - _ = x[StatusPolicyViolation-1012] - _ = x[StatusMessageTooBig-1013] - _ = x[StatusMandatoryExtension-1014] - _ = x[StatusInternalError-1015] - _ = x[StatusServiceRestart-1016] - _ = x[StatusTryAgainLater-1017] - _ = x[StatusBadGateway-1018] - _ = x[StatusTLSHandshake-1019] + _ = x[StatusNoStatusRcvd-1005] + _ = x[StatusAbnormalClosure-1006] + _ = x[StatusInvalidFramePayloadData-1007] + _ = x[StatusPolicyViolation-1008] + _ = x[StatusMessageTooBig-1009] + _ = x[StatusMandatoryExtension-1010] + _ = x[StatusInternalError-1011] + _ = x[StatusServiceRestart-1012] + _ = x[StatusTryAgainLater-1013] + _ = x[StatusBadGateway-1014] + _ = x[StatusTLSHandshake-1015] } const ( @@ -40,8 +40,8 @@ func (i StatusCode) String() string { case 1000 <= i && i <= 1003: i -= 1000 return _StatusCode_name_0[_StatusCode_index_0[i]:_StatusCode_index_0[i+1]] - case 1009 <= i && i <= 1019: - i -= 1009 + case 1005 <= i && i <= 1015: + i -= 1005 return _StatusCode_name_1[_StatusCode_index_1[i]:_StatusCode_index_1[i+1]] default: return "StatusCode(" + strconv.FormatInt(int64(i), 10) + ")" diff --git a/websocket.go b/websocket.go index d25f55ed..b36b2a74 100644 --- a/websocket.go +++ b/websocket.go @@ -23,9 +23,10 @@ type controlFrame struct { type Conn struct { subprotocol string br *bufio.Reader - bw *bufio.Writer - closer io.Closer - client bool + // TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large. + bw *bufio.Writer + closer io.Closer + client bool closeOnce sync.Once closeErr error @@ -39,11 +40,11 @@ type Conn struct { // Readers should receive on read to begin reading a message. // Then send a byte slice to readBytes to read into it. - // A value on done will be sent once the read into a slice is complete. - // done will be closed when the message has been fully read. + // The n of bytes read will be sent on readDone once the read into a slice is complete. + // readDone will receive 0 when EOF is reached. read chan opcode readBytes chan []byte - readDone chan struct{} + readDone chan int } func (c *Conn) getCloseErr() error { @@ -82,6 +83,7 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.write = make(chan opcode) c.read = make(chan opcode) + c.readDone = make(chan int) c.readBytes = make(chan []byte) runtime.SetFinalizer(c, func(c *Conn) { @@ -109,20 +111,40 @@ messageLoop: case <-c.closed: return case b, ok := <-c.writeBytes: - if !ok { - if !opcode.controlOp() { - h := header{ - fin: true, - opcode: opContinuation, - masked: c.client, - } - b = marshalHeader(h) - _, err := c.bw.Write(b) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } + if !firstSent || !opcode.controlOp() { + h := header{ + fin: opcode.controlOp() || !ok, + opcode: opcode, + payloadLength: int64(len(b)), + masked: c.client, } + + if firstSent { + h.opcode = opContinuation + } + firstSent = true + + b2 := marshalHeader(h) + _, err := c.bw.Write(b2) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + + _, err = c.bw.Write(b) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + } + + if ok { + select { + case <-c.closed: + return + case c.writeBytes <- nil: + } + } else { err := c.bw.Flush() if err != nil { c.close(xerrors.Errorf("failed to write to connection: %v", err)) @@ -134,37 +156,46 @@ messageLoop: } continue messageLoop } + } + } + } +} - h := header{ - fin: opcode.controlOp(), - opcode: opcode, - payloadLength: int64(len(b)), - masked: c.client, - } +func (c *Conn) handleControl(h header) { + if h.payloadLength > maxControlFramePayload { + c.Close(StatusProtocolError, "control frame too large") + return + } - if firstSent { - h.opcode = opContinuation - } - firstSent = true + b := make([]byte, h.payloadLength) + _, err := io.ReadFull(c.br, b) + if err != nil { + c.close(xerrors.Errorf("failed to read control frame payload: %v", err)) + return + } - b2 := marshalHeader(h) - _, err := c.bw.Write(b2) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } + if h.masked { + mask(h.maskKey, 0, b) + } - _, err = c.bw.Write(b) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } - } + switch h.opcode { + case opPing: + c.writePong(b) + case opPong: + case opClose: + code, reason, err := parseClosePayload(b) + if err != nil { + c.close(xerrors.Errorf("read invalid close payload: %v", err)) + return } + c.Close(code, reason) + default: + panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } } func (c *Conn) readLoop() { + var indata bool for { h, err := readHeader(c.br) if err != nil { @@ -172,33 +203,18 @@ func (c *Conn) readLoop() { return } - switch h.opcode { - case opClose, opPing: - if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return - } - b := make([]byte, h.payloadLength) - _, err = io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read control frame payload: %v", err)) - return - } - - if h.opcode == opPing { - c.writePing(b) - continue - } - - code, reason, err := parseClosePayload(b) - if err != nil { - c.close(xerrors.Errorf("invalid close payload: %v", err)) - return - } - c.Close(code, reason) + if h.rsv1 || h.rsv2 || h.rsv3 { + c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) return } + // TODO this is fucked, as if they are reading a frame as they are writing, then we can't send ping/close so we'll just get stuck for 5s. + switch h.opcode { + case opClose, opPing, opPong: + c.handleControl(h) + continue + } + switch h.opcode { case opBinary, opText: default: @@ -206,55 +222,69 @@ func (c *Conn) readLoop() { return } - c.readDone = make(chan struct{}) - c.read <- h.opcode - for { - var maskPos int - left := h.payloadLength - for left > 0 { - select { - case <-c.closed: - return - case b := <-c.readBytes: - if int64(len(b)) > left { - b = b[:left] - } + if !indata { + select { + case <-c.closed: + return + case c.read <- h.opcode: + } + } else { + indata = true + } - _, err = io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read from connection: %v", err)) - return - } - left -= int64(len(b)) + var maskPos int + left := h.payloadLength + for left > 0 { + select { + case <-c.closed: + return + case b := <-c.readBytes: + if int64(len(b)) > left { + b = b[:left] + } - if h.masked { - maskPos = mask(h.maskKey, maskPos, b) - } + _, err = io.ReadFull(c.br, b) + if err != nil { + c.close(xerrors.Errorf("failed to read from connection: %v", err)) + return + } + left -= int64(len(b)) - select { - case <-c.closed: - return - case c.readDone <- struct{}{}: - } + if h.masked { + maskPos = mask(h.maskKey, maskPos, b) } - } - if h.fin { - break + select { + case <-c.closed: + return + case c.readDone <- len(b): + } } - h, err = readHeader(c.br) - if err != nil { - c.close(xerrors.Errorf("failed to read header: %v", err)) + } + + if h.fin { + indata = false + select { + case <-c.closed: return + case c.readDone <- 0: } - // TODO check opcode. } - close(c.readDone) } } -func (c *Conn) writePing(p []byte) { - panic("TODO") +func (c *Conn) writePong(p []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + w := c.messageWriter(opPong) + w.SetContext(ctx) + _, err := w.Write(p) + if err != nil { + return err + } + err = w.Close() + return err } // MessageWriter returns a writer bounded by the context that will write @@ -281,14 +311,14 @@ func (c *Conn) messageWriter(opcode opcode) *MessageWriter { func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { select { case <-c.closed: - return 0, nil, c.getCloseErr() + return 0, nil, xerrors.Errorf("failed to read message: %v", c.getCloseErr()) case opcode := <-c.read: return DataType(opcode), &MessageReader{ ctx: context.Background(), c: c, }, nil case <-ctx.Done(): - return 0, nil, ctx.Err() + return 0, nil, xerrors.Errorf("failed to read message: %v", ctx.Err()) } } @@ -313,26 +343,36 @@ func (c *Conn) Close(code StatusCode, reason string) error { case c.write <- opClose: case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) + return c.getCloseErr() } select { case <-c.closed: return c.getCloseErr() case c.writeBytes <- p: - close(c.writeBytes) + select { + case <-c.closed: + return c.getCloseErr() + case <-c.writeBytes: + close(c.writeBytes) + case <-ctx.Done(): + return ctx.Err() + } case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) + return c.getCloseErr() } select { case <-c.closed: + if err != nil { + return err + } + return c.closeErr case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) + return c.getCloseErr() } - if err != nil { - return err - } - return c.closeErr } // MessageWriter enables writing to a WebSocket connection. @@ -367,7 +407,14 @@ func (w *MessageWriter) Write(p []byte) (int, error) { case <-w.c.closed: return 0, w.c.getCloseErr() case w.c.writeBytes <- p: - return len(p), nil + select { + case <-w.c.closed: + return 0, w.c.getCloseErr() + case <-w.c.writeBytes: + return len(p), nil + case <-w.ctx.Done(): + return 0, w.ctx.Err() + } case <-w.ctx.Done(): return 0, w.ctx.Err() } @@ -382,7 +429,14 @@ func (w *MessageWriter) SetContext(ctx context.Context) { // This must be called for every MessageWriter. func (w *MessageWriter) Close() error { if !w.acquiredLock { - return xerrors.New("websocket: MessageWriter closed without writing any bytes") + select { + case <-w.c.closed: + return w.c.getCloseErr() + case w.c.write <- w.opcode: + w.acquiredLock = true + case <-w.ctx.Done(): + return w.ctx.Err() + } } close(w.c.writeBytes) return nil @@ -419,13 +473,13 @@ func (r *MessageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: return 0, r.c.getCloseErr() - case <-r.c.readDone: - r.n += len(p) + case n := <-r.c.readDone: + r.n += n // TODO make this better later and inside readLoop to prevent the read from actually occuring if over limit. - if r.limit > 0 && n > r.limit { + if r.limit > 0 && r.n > r.limit { return 0, xerrors.New("message too big") } - return len(p), nil + return n, nil case <-r.ctx.Done(): return 0, r.ctx.Err() } diff --git a/websocket_test.go b/websocket_test.go index 06187075..554ba7b1 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -2,8 +2,15 @@ package websocket_test import ( "context" + "encoding/json" + "golang.org/x/time/rate" + "io" + "io/ioutil" "net/http" "net/http/httptest" + "os" + "os/exec" + "strings" "testing" "time" @@ -40,7 +47,6 @@ func TestConnection(t *testing.T) { return } - t.Log("success", v) obj <- v c.Close(websocket.StatusNormalClosure, "") @@ -87,3 +93,141 @@ func TestConnection(t *testing.T) { t.Fatalf("test timed out") } } + +func TestAutobahn(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, + websocket.AcceptSubprotocols("echo"), + ) + if err != nil { + t.Logf("server handshake failed: %v", err) + return + } + defer c.Close(websocket.StatusInternalError, "") + + ctx := context.Background() + + echo := func() error { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + typ, r, err := c.ReadMessage(ctx) + if err != nil { + return err + } + + ctx, cancel = context.WithTimeout(ctx, time.Second*10) + defer cancel() + + r.SetContext(ctx) + r.Limit(131072) + + w := c.MessageWriter(typ) + w.SetContext(ctx) + _, err = io.Copy(w, r) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return nil + } + + l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) + for l.Allow() { + err := echo() + if err != nil { + t.Logf("%v: failed to echo message: %+v", time.Now(), err) + return + } + } + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "wstest_reports/server", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + "options": map[string]interface{}{ + "version": 18, + }, + }, + }, + "cases": []string{"*"}, + "exclude-cases": []interface{}{}, + } + specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + if os.Getenv("CI") == "" { + args = append([]string{"--debug"}, args...) + } + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + b, err := ioutil.ReadFile("./wstest_reports/server/index.json") + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + if testing.Verbose() { + t.Logf("output: %s", out) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + } + err = json.Unmarshal(b, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + if result.Behavior != "OK" { + failed = true + t.Errorf("test %v failed", test) + } + } + } + + if failed { + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, please see ./wstest_reports/server/index.html") + } else { + t.Errorf("wstest found failure, please run test.sh locally to see ./wstest_reports/server/index.html") + } + } +} From 11e252151ec1b23467f5943d2efe911faab15658 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 3/4] Add autobahn testsuite to CI --- .github/test/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/test/Dockerfile b/.github/test/Dockerfile index c6874db3..ec6c4769 100644 --- a/.github/test/Dockerfile +++ b/.github/test/Dockerfile @@ -5,7 +5,9 @@ LABEL "com.github.actions.description"="test" LABEL "com.github.actions.icon"="code" LABEL "com.github.actions.color"="purple" -RUN apt update && apt install -y shellcheck +RUN apt update && \ + apt install -y shellcheck python-pip && \ + pip install autobahntestsuite COPY entrypoint.sh /entrypoint.sh From b9465b2e32c33fb50e7c60e8a129df49aefd22ab Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 4/4] Get near full autobahn suite passing --- statuscode.go | 30 ++++++ websocket.go | 245 +++++++++++++++++++++++++++------------------- websocket_test.go | 15 ++- 3 files changed, 180 insertions(+), 110 deletions(-) diff --git a/statuscode.go b/statuscode.go index 790de712..ed7e64b7 100644 --- a/statuscode.go +++ b/statuscode.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/bits" + "unicode/utf8" "golang.org/x/xerrors" ) @@ -53,9 +54,38 @@ func parseClosePayload(p []byte) (code StatusCode, reason string, err error) { code = StatusCode(binary.BigEndian.Uint16(p)) reason = string(p[2:]) + if !utf8.ValidString(reason) { + return 0, "", xerrors.Errorf("invalid utf-8: %q", reason) + } + if !isValidReceivedCloseCode(code) { + return 0, "", xerrors.Errorf("invalid code %v", code) + } + return code, reason, nil } +var validReceivedCloseCodes = map[StatusCode]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + StatusNormalClosure: true, + StatusGoingAway: true, + StatusProtocolError: true, + StatusUnsupportedData: true, + StatusNoStatusRcvd: false, + StatusAbnormalClosure: false, + StatusInvalidFramePayloadData: true, + StatusPolicyViolation: true, + StatusMessageTooBig: true, + StatusMandatoryExtension: true, + StatusInternalError: true, + StatusServiceRestart: true, + StatusTryAgainLater: true, + StatusTLSHandshake: false, +} + +func isValidReceivedCloseCode(code StatusCode) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + const maxControlFramePayload = 125 func closePayload(code StatusCode, reason string) ([]byte, error) { diff --git a/websocket.go b/websocket.go index b36b2a74..69781251 100644 --- a/websocket.go +++ b/websocket.go @@ -12,9 +12,9 @@ import ( "golang.org/x/xerrors" ) -type controlFrame struct { - header header - data []byte +type control struct { + opcode opcode + payload []byte } // Conn represents a WebSocket connection. @@ -35,8 +35,10 @@ type Conn struct { // Writers should send on write to begin sending // a message and then follow that up with some data // on writeBytes. - write chan opcode + write chan DataType + control chan control writeBytes chan []byte + writeDone chan struct{} // Readers should receive on read to begin reading a message. // Then send a byte slice to readBytes to read into it. @@ -81,7 +83,9 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) - c.write = make(chan opcode) + c.write = make(chan DataType) + c.control = make(chan control) + c.writeDone = make(chan struct{}) c.read = make(chan opcode) c.readDone = make(chan int) c.readBytes = make(chan []byte) @@ -94,15 +98,49 @@ func (c *Conn) init() { go c.readLoop() } +func (c *Conn) writeFrame(h header, p []byte) { + b2 := marshalHeader(h) + _, err := c.bw.Write(b2) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + + _, err = c.bw.Write(p) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + + if h.opcode.controlOp() { + err := c.bw.Flush() + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + } +} + func (c *Conn) writeLoop() { messageLoop: for { c.writeBytes = make(chan []byte) - var opcode opcode + + var dataType DataType select { case <-c.closed: return - case opcode = <-c.write: + case dataType = <-c.write: + case control := <-c.control: + h := header{ + fin: true, + opcode: control.opcode, + payloadLength: int64(len(control.payload)), + masked: c.client, + } + c.writeFrame(h, control.payload) + c.writeDone <- struct{}{} + continue } var firstSent bool @@ -110,51 +148,48 @@ messageLoop: select { case <-c.closed: return + case control := <-c.control: + h := header{ + fin: true, + opcode: control.opcode, + payloadLength: int64(len(control.payload)), + masked: c.client, + } + c.writeFrame(h, control.payload) + c.writeDone <- struct{}{} + continue case b, ok := <-c.writeBytes: - if !firstSent || !opcode.controlOp() { - h := header{ - fin: opcode.controlOp() || !ok, - opcode: opcode, - payloadLength: int64(len(b)), - masked: c.client, - } + h := header{ + fin: !ok, + opcode: opcode(dataType), + payloadLength: int64(len(b)), + masked: c.client, + } - if firstSent { - h.opcode = opContinuation - } - firstSent = true + if firstSent { + h.opcode = opContinuation + } + firstSent = true - b2 := marshalHeader(h) - _, err := c.bw.Write(b2) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } + c.writeFrame(h, b) - _, err = c.bw.Write(b) + if !ok { + err := c.bw.Flush() if err != nil { c.close(xerrors.Errorf("failed to write to connection: %v", err)) return } } - if ok { - select { - case <-c.closed: - return - case c.writeBytes <- nil: - } - } else { - err := c.bw.Flush() - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } - if opcode == opClose { - c.close(nil) - return + select { + case <-c.closed: + return + case c.writeDone <- struct{}{}: + if ok { + continue + } else { + continue messageLoop } - continue messageLoop } } } @@ -167,6 +202,11 @@ func (c *Conn) handleControl(h header) { return } + if !h.fin { + c.Close(StatusProtocolError, "control frame cannot be fragmented") + return + } + b := make([]byte, h.payloadLength) _, err := io.ReadFull(c.br, b) if err != nil { @@ -183,12 +223,20 @@ func (c *Conn) handleControl(h header) { c.writePong(b) case opPong: case opClose: - code, reason, err := parseClosePayload(b) - if err != nil { - c.close(xerrors.Errorf("read invalid close payload: %v", err)) - return + if len(b) > 0 { + code, reason, err := parseClosePayload(b) + if err != nil { + c.close(xerrors.Errorf("read invalid close payload: %v", err)) + return + } + c.Close(code, reason) + } else { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + c.writeControl(ctx, opClose, nil) + c.close(nil) } - c.Close(code, reason) default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } @@ -208,33 +256,38 @@ func (c *Conn) readLoop() { return } - // TODO this is fucked, as if they are reading a frame as they are writing, then we can't send ping/close so we'll just get stuck for 5s. - switch h.opcode { - case opClose, opPing, opPong: + if h.opcode.controlOp() { c.handleControl(h) continue } switch h.opcode { case opBinary, opText: + if !indata { + select { + case <-c.closed: + return + case c.read <- h.opcode: + } + indata = true + } else { + c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished") + return + } + case opContinuation: + if !indata { + c.Close(StatusProtocolError, "continuation frame not after data or text frame") + return + } default: c.close(xerrors.Errorf("unexpected opcode in header: %#v", h)) return } - if !indata { - select { - case <-c.closed: - return - case c.read <- h.opcode: - } - } else { - indata = true - } - - var maskPos int + maskPos := 0 left := h.payloadLength - for left > 0 { + firstRead := false + for left > 0 || !firstRead { select { case <-c.closed: return @@ -258,6 +311,7 @@ func (c *Conn) readLoop() { case <-c.closed: return case c.readDone <- len(b): + firstRead = true } } } @@ -277,13 +331,7 @@ func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - w := c.messageWriter(opPong) - w.SetContext(ctx) - _, err := w.Write(p) - if err != nil { - return err - } - err = w.Close() + err := c.writeControl(ctx, opPong, p) return err } @@ -292,14 +340,10 @@ func (c *Conn) writePong(p []byte) error { // Ensure you close the MessageWriter once you have written to entire message. // Concurrent calls to MessageWriter are ok. func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { - return c.messageWriter(opcode(dataType)) -} - -func (c *Conn) messageWriter(opcode opcode) *MessageWriter { return &MessageWriter{ - c: c, - ctx: context.Background(), - opcode: opcode, + c: c, + ctx: context.Background(), + datatype: dataType, } } @@ -337,27 +381,27 @@ func (c *Conn) Close(code StatusCode, reason string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - select { - case <-c.closed: - return c.getCloseErr() - case c.write <- opClose: - case <-ctx.Done(): - c.close(xerrors.New("force closed: close frame write timed out")) - return c.getCloseErr() + err = c.writeControl(ctx, opClose, p) + if err != nil { + return err + } + + c.close(nil) + + if err != nil { + return err } + return c.closeErr +} +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { select { case <-c.closed: return c.getCloseErr() - case c.writeBytes <- p: - select { - case <-c.closed: - return c.getCloseErr() - case <-c.writeBytes: - close(c.writeBytes) - case <-ctx.Done(): - return ctx.Err() - } + case c.control <- control{ + opcode: opcode, + payload: p, + }: case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) return c.getCloseErr() @@ -365,20 +409,18 @@ func (c *Conn) Close(code StatusCode, reason string) error { select { case <-c.closed: - if err != nil { - return err - } - return c.closeErr - case <-ctx.Done(): - c.close(xerrors.New("force closed: close frame write timed out")) return c.getCloseErr() + case <-c.writeDone: + return nil + case <-ctx.Done(): + return ctx.Err() } } // MessageWriter enables writing to a WebSocket connection. // Ensure you close the MessageWriter once you have written to entire message. type MessageWriter struct { - opcode opcode + datatype DataType ctx context.Context c *Conn acquiredLock bool @@ -396,7 +438,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) { select { case <-w.c.closed: return 0, w.c.getCloseErr() - case w.c.write <- w.opcode: + case w.c.write <- w.datatype: w.acquiredLock = true case <-w.ctx.Done(): return 0, w.ctx.Err() @@ -410,7 +452,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) { select { case <-w.c.closed: return 0, w.c.getCloseErr() - case <-w.c.writeBytes: + case <-w.c.writeDone: return len(p), nil case <-w.ctx.Done(): return 0, w.ctx.Err() @@ -432,13 +474,14 @@ func (w *MessageWriter) Close() error { select { case <-w.c.closed: return w.c.getCloseErr() - case w.c.write <- w.opcode: + case w.c.write <- w.datatype: w.acquiredLock = true case <-w.ctx.Done(): return w.ctx.Err() } } close(w.c.writeBytes) + <-w.c.writeDone return nil } diff --git a/websocket_test.go b/websocket_test.go index 554ba7b1..b4b19c62 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -3,7 +3,6 @@ package websocket_test import ( "context" "encoding/json" - "golang.org/x/time/rate" "io" "io/ioutil" "net/http" @@ -118,14 +117,11 @@ func TestAutobahn(t *testing.T) { return err } - ctx, cancel = context.WithTimeout(ctx, time.Second*10) - defer cancel() - r.SetContext(ctx) - r.Limit(131072) w := c.MessageWriter(typ) w.SetContext(ctx) + _, err = io.Copy(w, r) if err != nil { return err @@ -139,8 +135,7 @@ func TestAutobahn(t *testing.T) { return nil } - l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) - for l.Allow() { + for { err := echo() if err != nil { t.Logf("%v: failed to echo message: %+v", time.Now(), err) @@ -162,7 +157,7 @@ func TestAutobahn(t *testing.T) { }, }, "cases": []string{"*"}, - "exclude-cases": []interface{}{}, + "exclude-cases": []string{"6.*", "12.*", "13.*"}, } specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") if err != nil { @@ -216,7 +211,9 @@ func TestAutobahn(t *testing.T) { var failed bool for _, tests := range indexJSON { for test, result := range tests { - if result.Behavior != "OK" { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: failed = true t.Errorf("test %v failed", test) }