diff --git a/.github/test/Dockerfile b/.github/test/Dockerfile index c6874db3..ec6c4769 100644 --- a/.github/test/Dockerfile +++ b/.github/test/Dockerfile @@ -5,7 +5,9 @@ LABEL "com.github.actions.description"="test" LABEL "com.github.actions.icon"="code" LABEL "com.github.actions.color"="purple" -RUN apt update && apt install -y shellcheck +RUN apt update && \ + apt install -y shellcheck python-pip && \ + pip install autobahntestsuite COPY entrypoint.sh /entrypoint.sh diff --git a/.gitignore b/.gitignore index 4383ca89..70d8e703 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ coverage.html +wstest_reports diff --git a/opcode.go b/opcode.go index de54fdb4..2469e781 100644 --- a/opcode.go +++ b/opcode.go @@ -10,7 +10,7 @@ const ( opText opBinary // 3 - 7 are reserved for further non-control frames. - opClose opcode = 8 + iota + opClose opcode = 8 + iota - 3 opPing opPong // 11-16 are reserved for further control frames. diff --git a/opcode_string.go b/opcode_string.go index e815cdea..740b5e70 100644 --- a/opcode_string.go +++ b/opcode_string.go @@ -11,9 +11,9 @@ func _() { _ = x[opContinuation-0] _ = x[opText-1] _ = x[opBinary-2] - _ = x[opClose-11] - _ = x[opPing-12] - _ = x[opPong-13] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] } const ( @@ -30,8 +30,8 @@ func (i opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] - case 11 <= i && i <= 13: - i -= 11 + case 8 <= i && i <= 10: + i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" diff --git a/statuscode.go b/statuscode.go index 7efbfcd1..ed7e64b7 100644 --- a/statuscode.go +++ b/statuscode.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/bits" + "unicode/utf8" "golang.org/x/xerrors" ) @@ -21,7 +22,7 @@ const ( StatusProtocolError StatusUnsupportedData // 1004 is reserved. - StatusNoStatusRcvd StatusCode = 1005 + iota + StatusNoStatusRcvd StatusCode = 1005 + iota - 4 StatusAbnormalClosure StatusInvalidFramePayloadData StatusPolicyViolation @@ -53,9 +54,38 @@ func parseClosePayload(p []byte) (code StatusCode, reason string, err error) { code = StatusCode(binary.BigEndian.Uint16(p)) reason = string(p[2:]) + if !utf8.ValidString(reason) { + return 0, "", xerrors.Errorf("invalid utf-8: %q", reason) + } + if !isValidReceivedCloseCode(code) { + return 0, "", xerrors.Errorf("invalid code %v", code) + } + return code, reason, nil } +var validReceivedCloseCodes = map[StatusCode]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + StatusNormalClosure: true, + StatusGoingAway: true, + StatusProtocolError: true, + StatusUnsupportedData: true, + StatusNoStatusRcvd: false, + StatusAbnormalClosure: false, + StatusInvalidFramePayloadData: true, + StatusPolicyViolation: true, + StatusMessageTooBig: true, + StatusMandatoryExtension: true, + StatusInternalError: true, + StatusServiceRestart: true, + StatusTryAgainLater: true, + StatusTLSHandshake: false, +} + +func isValidReceivedCloseCode(code StatusCode) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + const maxControlFramePayload = 125 func closePayload(code StatusCode, reason string) ([]byte, error) { diff --git a/statuscode_string.go b/statuscode_string.go index e1bfc462..fc8cea0d 100644 --- a/statuscode_string.go +++ b/statuscode_string.go @@ -12,17 +12,17 @@ func _() { _ = x[StatusGoingAway-1001] _ = x[StatusProtocolError-1002] _ = x[StatusUnsupportedData-1003] - _ = x[StatusNoStatusRcvd-1009] - _ = x[StatusAbnormalClosure-1010] - _ = x[StatusInvalidFramePayloadData-1011] - _ = x[StatusPolicyViolation-1012] - _ = x[StatusMessageTooBig-1013] - _ = x[StatusMandatoryExtension-1014] - _ = x[StatusInternalError-1015] - _ = x[StatusServiceRestart-1016] - _ = x[StatusTryAgainLater-1017] - _ = x[StatusBadGateway-1018] - _ = x[StatusTLSHandshake-1019] + _ = x[StatusNoStatusRcvd-1005] + _ = x[StatusAbnormalClosure-1006] + _ = x[StatusInvalidFramePayloadData-1007] + _ = x[StatusPolicyViolation-1008] + _ = x[StatusMessageTooBig-1009] + _ = x[StatusMandatoryExtension-1010] + _ = x[StatusInternalError-1011] + _ = x[StatusServiceRestart-1012] + _ = x[StatusTryAgainLater-1013] + _ = x[StatusBadGateway-1014] + _ = x[StatusTLSHandshake-1015] } const ( @@ -40,8 +40,8 @@ func (i StatusCode) String() string { case 1000 <= i && i <= 1003: i -= 1000 return _StatusCode_name_0[_StatusCode_index_0[i]:_StatusCode_index_0[i+1]] - case 1009 <= i && i <= 1019: - i -= 1009 + case 1005 <= i && i <= 1015: + i -= 1005 return _StatusCode_name_1[_StatusCode_index_1[i]:_StatusCode_index_1[i+1]] default: return "StatusCode(" + strconv.FormatInt(int64(i), 10) + ")" diff --git a/websocket.go b/websocket.go index 83ef0bb5..69781251 100644 --- a/websocket.go +++ b/websocket.go @@ -5,27 +5,28 @@ import ( "context" "fmt" "io" + "runtime" "sync" "time" "golang.org/x/xerrors" ) -type controlFrame struct { - header header - data []byte +type control struct { + opcode opcode + payload []byte } // Conn represents a WebSocket connection. // Pings will always be automatically responded to with pongs, you do not // have to do anything special. -// TODO set finalizer type Conn struct { subprotocol string br *bufio.Reader - bw *bufio.Writer - closer io.Closer - client bool + // TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large. + bw *bufio.Writer + closer io.Closer + client bool closeOnce sync.Once closeErr error @@ -34,16 +35,18 @@ type Conn struct { // Writers should send on write to begin sending // a message and then follow that up with some data // on writeBytes. - write chan opcode + write chan DataType + control chan control writeBytes chan []byte + writeDone chan struct{} // Readers should receive on read to begin reading a message. // Then send a byte slice to readBytes to read into it. - // A value on done will be sent once the read into a slice is complete. - // done will be closed when the message has been fully read. + // The n of bytes read will be sent on readDone once the read into a slice is complete. + // readDone will receive 0 when EOF is reached. read chan opcode readBytes chan []byte - readDone chan struct{} + readDone chan int } func (c *Conn) getCloseErr() error { @@ -59,6 +62,8 @@ func (c *Conn) close(err error) { } c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) + c.closeErr = err cerr := c.closer.Close() @@ -78,23 +83,64 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) - c.write = make(chan opcode) + c.write = make(chan DataType) + c.control = make(chan control) + c.writeDone = make(chan struct{}) c.read = make(chan opcode) + c.readDone = make(chan int) c.readBytes = make(chan []byte) + runtime.SetFinalizer(c, func(c *Conn) { + c.Close(StatusInternalError, "websocket: connection ended up being garbage collected") + }) + go c.writeLoop() go c.readLoop() } +func (c *Conn) writeFrame(h header, p []byte) { + b2 := marshalHeader(h) + _, err := c.bw.Write(b2) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + + _, err = c.bw.Write(p) + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + + if h.opcode.controlOp() { + err := c.bw.Flush() + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } + } +} + func (c *Conn) writeLoop() { messageLoop: for { c.writeBytes = make(chan []byte) - var opcode opcode + + var dataType DataType select { case <-c.closed: return - case opcode = <-c.write: + case dataType = <-c.write: + case control := <-c.control: + h := header{ + fin: true, + opcode: control.opcode, + payloadLength: int64(len(control.payload)), + masked: c.client, + } + c.writeFrame(h, control.payload) + c.writeDone <- struct{}{} + continue } var firstSent bool @@ -102,36 +148,20 @@ messageLoop: select { case <-c.closed: return - case b, ok := <-c.writeBytes: - if !ok { - if !opcode.controlOp() { - h := header{ - fin: true, - opcode: opContinuation, - masked: c.client, - } - b = marshalHeader(h) - _, err := c.bw.Write(b) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } - } - err := c.bw.Flush() - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return - } - if opcode == opClose { - c.close(nil) - return - } - continue messageLoop + case control := <-c.control: + h := header{ + fin: true, + opcode: control.opcode, + payloadLength: int64(len(control.payload)), + masked: c.client, } - + c.writeFrame(h, control.payload) + c.writeDone <- struct{}{} + continue + case b, ok := <-c.writeBytes: h := header{ - fin: opcode.controlOp(), - opcode: opcode, + fin: !ok, + opcode: opcode(dataType), payloadLength: int64(len(b)), masked: c.client, } @@ -141,114 +171,168 @@ messageLoop: } firstSent = true - b2 := marshalHeader(h) - _, err := c.bw.Write(b2) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) - return + c.writeFrame(h, b) + + if !ok { + err := c.bw.Flush() + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %v", err)) + return + } } - _, err = c.bw.Write(b) - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %v", err)) + select { + case <-c.closed: return + case c.writeDone <- struct{}{}: + if ok { + continue + } else { + continue messageLoop + } } } } } } -func (c *Conn) readLoop() { - for { - h, err := readHeader(c.br) - if err != nil { - c.close(xerrors.Errorf("failed to read header: %v", err)) - return - } +func (c *Conn) handleControl(h header) { + if h.payloadLength > maxControlFramePayload { + c.Close(StatusProtocolError, "control frame too large") + return + } - switch h.opcode { - case opClose, opPing: - if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return - } - b := make([]byte, h.payloadLength) - _, err = io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read control frame payload: %v", err)) - return - } + if !h.fin { + c.Close(StatusProtocolError, "control frame cannot be fragmented") + return + } - if h.opcode == opPing { - c.writePing(b) - continue - } + b := make([]byte, h.payloadLength) + _, err := io.ReadFull(c.br, b) + if err != nil { + c.close(xerrors.Errorf("failed to read control frame payload: %v", err)) + return + } + if h.masked { + mask(h.maskKey, 0, b) + } + + switch h.opcode { + case opPing: + c.writePong(b) + case opPong: + case opClose: + if len(b) > 0 { code, reason, err := parseClosePayload(b) if err != nil { - c.close(xerrors.Errorf("invalid close payload: %v", err)) + c.close(xerrors.Errorf("read invalid close payload: %v", err)) return } c.Close(code, reason) + } else { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + c.writeControl(ctx, opClose, nil) + c.close(nil) + } + default: + panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) + } +} + +func (c *Conn) readLoop() { + var indata bool + for { + h, err := readHeader(c.br) + if err != nil { + c.close(xerrors.Errorf("failed to read header: %v", err)) return } + if h.rsv1 || h.rsv2 || h.rsv3 { + c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + return + } + + if h.opcode.controlOp() { + c.handleControl(h) + continue + } + switch h.opcode { case opBinary, opText: + if !indata { + select { + case <-c.closed: + return + case c.read <- h.opcode: + } + indata = true + } else { + c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished") + return + } + case opContinuation: + if !indata { + c.Close(StatusProtocolError, "continuation frame not after data or text frame") + return + } default: c.close(xerrors.Errorf("unexpected opcode in header: %#v", h)) return } - c.readDone = make(chan struct{}) - c.read <- h.opcode - for { - var maskPos int - left := h.payloadLength - for left > 0 { - select { - case <-c.closed: - return - case b := <-c.readBytes: - if int64(len(b)) > left { - b = b[:left] - } + maskPos := 0 + left := h.payloadLength + firstRead := false + for left > 0 || !firstRead { + select { + case <-c.closed: + return + case b := <-c.readBytes: + if int64(len(b)) > left { + b = b[:left] + } - _, err = io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read from connection: %v", err)) - return - } - left -= int64(len(b)) + _, err = io.ReadFull(c.br, b) + if err != nil { + c.close(xerrors.Errorf("failed to read from connection: %v", err)) + return + } + left -= int64(len(b)) - if h.masked { - maskPos = mask(h.maskKey, maskPos, b) - } + if h.masked { + maskPos = mask(h.maskKey, maskPos, b) + } - select { - case <-c.closed: - return - case c.readDone <- struct{}{}: - } + select { + case <-c.closed: + return + case c.readDone <- len(b): + firstRead = true } } + } - if h.fin { - break - } - h, err = readHeader(c.br) - if err != nil { - c.close(xerrors.Errorf("failed to read header: %v", err)) + if h.fin { + indata = false + select { + case <-c.closed: return + case c.readDone <- 0: } - // TODO check opcode. } - close(c.readDone) } } -func (c *Conn) writePing(p []byte) { - panic("TODO") +func (c *Conn) writePong(p []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opPong, p) + return err } // MessageWriter returns a writer bounded by the context that will write @@ -256,14 +340,10 @@ func (c *Conn) writePing(p []byte) { // Ensure you close the MessageWriter once you have written to entire message. // Concurrent calls to MessageWriter are ok. func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { - return c.messageWriter(opcode(dataType)) -} - -func (c *Conn) messageWriter(opcode opcode) *MessageWriter { return &MessageWriter{ - c: c, - ctx: context.Background(), - opcode: opcode, + c: c, + ctx: context.Background(), + datatype: dataType, } } @@ -275,14 +355,14 @@ func (c *Conn) messageWriter(opcode opcode) *MessageWriter { func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { select { case <-c.closed: - return 0, nil, c.getCloseErr() + return 0, nil, xerrors.Errorf("failed to read message: %v", c.getCloseErr()) case opcode := <-c.read: return DataType(opcode), &MessageReader{ ctx: context.Background(), c: c, }, nil case <-ctx.Done(): - return 0, nil, ctx.Err() + return 0, nil, xerrors.Errorf("failed to read message: %v", ctx.Err()) } } @@ -301,38 +381,46 @@ func (c *Conn) Close(code StatusCode, reason string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - select { - case <-c.closed: - return c.getCloseErr() - case c.write <- opClose: - case <-ctx.Done(): - c.close(xerrors.New("force closed: close frame write timed out")) + err = c.writeControl(ctx, opClose, p) + if err != nil { + return err } + c.close(nil) + + if err != nil { + return err + } + return c.closeErr +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { select { case <-c.closed: return c.getCloseErr() - case c.writeBytes <- p: - close(c.writeBytes) + case c.control <- control{ + opcode: opcode, + payload: p, + }: case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) + return c.getCloseErr() } select { case <-c.closed: + return c.getCloseErr() + case <-c.writeDone: + return nil case <-ctx.Done(): - c.close(xerrors.New("force closed: close frame write timed out")) - } - if err != nil { - return err + return ctx.Err() } - return c.closeErr } // MessageWriter enables writing to a WebSocket connection. // Ensure you close the MessageWriter once you have written to entire message. type MessageWriter struct { - opcode opcode + datatype DataType ctx context.Context c *Conn acquiredLock bool @@ -350,7 +438,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) { select { case <-w.c.closed: return 0, w.c.getCloseErr() - case w.c.write <- w.opcode: + case w.c.write <- w.datatype: w.acquiredLock = true case <-w.ctx.Done(): return 0, w.ctx.Err() @@ -361,7 +449,14 @@ func (w *MessageWriter) Write(p []byte) (int, error) { case <-w.c.closed: return 0, w.c.getCloseErr() case w.c.writeBytes <- p: - return len(p), nil + select { + case <-w.c.closed: + return 0, w.c.getCloseErr() + case <-w.c.writeDone: + return len(p), nil + case <-w.ctx.Done(): + return 0, w.ctx.Err() + } case <-w.ctx.Done(): return 0, w.ctx.Err() } @@ -376,9 +471,17 @@ func (w *MessageWriter) SetContext(ctx context.Context) { // This must be called for every MessageWriter. func (w *MessageWriter) Close() error { if !w.acquiredLock { - return xerrors.New("websocket: MessageWriter closed without writing any bytes") + select { + case <-w.c.closed: + return w.c.getCloseErr() + case w.c.write <- w.datatype: + w.acquiredLock = true + case <-w.ctx.Done(): + return w.ctx.Err() + } } close(w.c.writeBytes) + <-w.c.writeDone return nil } @@ -413,13 +516,13 @@ func (r *MessageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: return 0, r.c.getCloseErr() - case <-r.c.readDone: - r.n += len(p) + case n := <-r.c.readDone: + r.n += n // TODO make this better later and inside readLoop to prevent the read from actually occuring if over limit. - if r.limit > 0 && n > r.limit { + if r.limit > 0 && r.n > r.limit { return 0, xerrors.New("message too big") } - return len(p), nil + return n, nil case <-r.ctx.Done(): return 0, r.ctx.Err() } diff --git a/websocket_test.go b/websocket_test.go index 06187075..b4b19c62 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -2,8 +2,14 @@ package websocket_test import ( "context" + "encoding/json" + "io" + "io/ioutil" "net/http" "net/http/httptest" + "os" + "os/exec" + "strings" "testing" "time" @@ -40,7 +46,6 @@ func TestConnection(t *testing.T) { return } - t.Log("success", v) obj <- v c.Close(websocket.StatusNormalClosure, "") @@ -87,3 +92,139 @@ func TestConnection(t *testing.T) { t.Fatalf("test timed out") } } + +func TestAutobahn(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, + websocket.AcceptSubprotocols("echo"), + ) + if err != nil { + t.Logf("server handshake failed: %v", err) + return + } + defer c.Close(websocket.StatusInternalError, "") + + ctx := context.Background() + + echo := func() error { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + typ, r, err := c.ReadMessage(ctx) + if err != nil { + return err + } + + r.SetContext(ctx) + + w := c.MessageWriter(typ) + w.SetContext(ctx) + + _, err = io.Copy(w, r) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return nil + } + + for { + err := echo() + if err != nil { + t.Logf("%v: failed to echo message: %+v", time.Now(), err) + return + } + } + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "wstest_reports/server", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + "options": map[string]interface{}{ + "version": 18, + }, + }, + }, + "cases": []string{"*"}, + "exclude-cases": []string{"6.*", "12.*", "13.*"}, + } + specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + if os.Getenv("CI") == "" { + args = append([]string{"--debug"}, args...) + } + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + b, err := ioutil.ReadFile("./wstest_reports/server/index.json") + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + if testing.Verbose() { + t.Logf("output: %s", out) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + } + err = json.Unmarshal(b, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + failed = true + t.Errorf("test %v failed", test) + } + } + } + + if failed { + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, please see ./wstest_reports/server/index.html") + } else { + t.Errorf("wstest found failure, please run test.sh locally to see ./wstest_reports/server/index.html") + } + } +}