Skip to content

Commit 97d7f90

Browse files
authored
Merge pull request #258 from abursavich/compression
Make compression negotiation more lenient
2 parents 81afa8a + d6b342b commit 97d7f90

File tree

6 files changed

+109
-91
lines changed

6 files changed

+109
-91
lines changed

Diff for: accept.go

+19-22
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
123123
w.Header().Set("Sec-WebSocket-Protocol", subproto)
124124
}
125125

126-
copts, err := acceptCompression(r, w, opts.CompressionMode)
127-
if err != nil {
128-
return nil, err
126+
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
127+
if ok {
128+
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
129129
}
130130

131131
w.WriteHeader(http.StatusSwitchingProtocols)
@@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
238238
return ""
239239
}
240240

241-
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
241+
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
242242
if mode == CompressionDisabled {
243-
return nil, nil
243+
return nil, false
244244
}
245-
246-
for _, ext := range websocketExtensions(r.Header) {
245+
for _, ext := range extensions {
247246
switch ext.name {
248247
// We used to implement x-webkit-deflate-fram too but Safari has bugs.
249248
// See https://github.com/nhooyr/websocket/issues/218
250249
case "permessage-deflate":
251-
return acceptDeflate(w, ext, mode)
250+
copts, ok := acceptDeflate(ext, mode)
251+
if ok {
252+
return copts, true
253+
}
252254
}
253255
}
254-
return nil, nil
256+
return nil, false
255257
}
256258

257-
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
259+
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
258260
copts := mode.opts()
259-
260261
for _, p := range ext.params {
261262
switch p {
262263
case "client_no_context_takeover":
@@ -265,22 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
265266
case "server_no_context_takeover":
266267
copts.serverNoContextTakeover = true
267268
continue
269+
case "client_max_window_bits",
270+
"server_max_window_bits=15":
271+
continue
268272
}
269273

270-
if strings.HasPrefix(p, "client_max_window_bits") {
271-
// We cannot adjust the read sliding window so cannot make use of this.
272-
// By not responding to it, we tell the client we're ignoring it.
274+
if strings.HasPrefix(p, "client_max_window_bits=") {
275+
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
273276
continue
274277
}
275-
276-
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
277-
http.Error(w, err.Error(), http.StatusBadRequest)
278-
return nil, err
278+
return nil, false
279279
}
280-
281-
copts.setHeader(w.Header())
282-
283-
return copts, nil
280+
return copts, true
284281
}
285282

286283
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {

Diff for: accept_test.go

+71-54
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,50 @@ func TestAccept(t *testing.T) {
6262
t.Run("badCompression", func(t *testing.T) {
6363
t.Parallel()
6464

65-
w := mockHijacker{
66-
ResponseWriter: httptest.NewRecorder(),
65+
newRequest := func(extensions string) *http.Request {
66+
r := httptest.NewRequest("GET", "/", nil)
67+
r.Header.Set("Connection", "Upgrade")
68+
r.Header.Set("Upgrade", "websocket")
69+
r.Header.Set("Sec-WebSocket-Version", "13")
70+
r.Header.Set("Sec-WebSocket-Key", "meow123")
71+
r.Header.Set("Sec-WebSocket-Extensions", extensions)
72+
return r
73+
}
74+
errHijack := errors.New("hijack error")
75+
newResponseWriter := func() http.ResponseWriter {
76+
return mockHijacker{
77+
ResponseWriter: httptest.NewRecorder(),
78+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
79+
return nil, nil, errHijack
80+
},
81+
}
6782
}
68-
r := httptest.NewRequest("GET", "/", nil)
69-
r.Header.Set("Connection", "Upgrade")
70-
r.Header.Set("Upgrade", "websocket")
71-
r.Header.Set("Sec-WebSocket-Version", "13")
72-
r.Header.Set("Sec-WebSocket-Key", "meow123")
73-
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
7483

75-
_, err := Accept(w, r, &AcceptOptions{
76-
CompressionMode: CompressionContextTakeover,
84+
t.Run("withoutFallback", func(t *testing.T) {
85+
t.Parallel()
86+
87+
w := newResponseWriter()
88+
r := newRequest("permessage-deflate; harharhar")
89+
_, err := Accept(w, r, &AcceptOptions{
90+
CompressionMode: CompressionNoContextTakeover,
91+
})
92+
assert.ErrorIs(t, errHijack, err)
93+
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
94+
})
95+
t.Run("withFallback", func(t *testing.T) {
96+
t.Parallel()
97+
98+
w := newResponseWriter()
99+
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
100+
_, err := Accept(w, r, &AcceptOptions{
101+
CompressionMode: CompressionNoContextTakeover,
102+
})
103+
assert.ErrorIs(t, errHijack, err)
104+
assert.Equal(t, "extension header",
105+
w.Header().Get("Sec-WebSocket-Extensions"),
106+
CompressionNoContextTakeover.opts().String(),
107+
)
77108
})
78-
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
79109
})
80110

81111
t.Run("requireHttpHijacker", func(t *testing.T) {
@@ -344,79 +374,66 @@ func Test_authenticateOrigin(t *testing.T) {
344374
}
345375
}
346376

347-
func Test_acceptCompression(t *testing.T) {
377+
func Test_selectDeflate(t *testing.T) {
348378
t.Parallel()
349379

350380
testCases := []struct {
351-
name string
352-
mode CompressionMode
353-
reqSecWebSocketExtensions string
354-
respSecWebSocketExtensions string
355-
expCopts *compressionOptions
356-
error bool
381+
name string
382+
mode CompressionMode
383+
header string
384+
expCopts *compressionOptions
385+
expOK bool
357386
}{
358387
{
359388
name: "disabled",
360389
mode: CompressionDisabled,
361390
expCopts: nil,
391+
expOK: false,
362392
},
363393
{
364394
name: "noClientSupport",
365395
mode: CompressionNoContextTakeover,
366396
expCopts: nil,
397+
expOK: false,
367398
},
368399
{
369-
name: "permessage-deflate",
370-
mode: CompressionNoContextTakeover,
371-
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
372-
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
400+
name: "permessage-deflate",
401+
mode: CompressionNoContextTakeover,
402+
header: "permessage-deflate; client_max_window_bits",
373403
expCopts: &compressionOptions{
374404
clientNoContextTakeover: true,
375405
serverNoContextTakeover: true,
376406
},
407+
expOK: true,
408+
},
409+
{
410+
name: "permessage-deflate/unknown-parameter",
411+
mode: CompressionNoContextTakeover,
412+
header: "permessage-deflate; meow",
413+
expOK: false,
377414
},
378415
{
379-
name: "permessage-deflate/error",
380-
mode: CompressionNoContextTakeover,
381-
reqSecWebSocketExtensions: "permessage-deflate; meow",
382-
error: true,
416+
name: "permessage-deflate/unknown-parameter",
417+
mode: CompressionNoContextTakeover,
418+
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
419+
expCopts: &compressionOptions{
420+
clientNoContextTakeover: true,
421+
serverNoContextTakeover: true,
422+
},
423+
expOK: true,
383424
},
384-
// {
385-
// name: "x-webkit-deflate-frame",
386-
// mode: CompressionNoContextTakeover,
387-
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
388-
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
389-
// expCopts: &compressionOptions{
390-
// clientNoContextTakeover: true,
391-
// serverNoContextTakeover: true,
392-
// },
393-
// },
394-
// {
395-
// name: "x-webkit-deflate/error",
396-
// mode: CompressionNoContextTakeover,
397-
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
398-
// error: true,
399-
// },
400425
}
401426

402427
for _, tc := range testCases {
403428
tc := tc
404429
t.Run(tc.name, func(t *testing.T) {
405430
t.Parallel()
406431

407-
r := httptest.NewRequest(http.MethodGet, "/", nil)
408-
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
409-
410-
w := httptest.NewRecorder()
411-
copts, err := acceptCompression(r, w, tc.mode)
412-
if tc.error {
413-
assert.Error(t, err)
414-
return
415-
}
416-
417-
assert.Success(t, err)
432+
h := http.Header{}
433+
h.Set("Sec-WebSocket-Extensions", tc.header)
434+
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
435+
assert.Equal(t, "selected options", tc.expOK, ok)
418436
assert.Equal(t, "compression options", tc.expCopts, copts)
419-
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
420437
})
421438
}
422439
}

Diff for: compress.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,13 @@ package websocket
66
import (
77
"compress/flate"
88
"io"
9-
"net/http"
109
"sync"
1110
)
1211

1312
// CompressionMode represents the modes available to the deflate extension.
1413
// See https://tools.ietf.org/html/rfc7692
1514
//
16-
// A compatibility layer is implemented for the older deflate-frame extension used
17-
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
18-
// It will work the same in every way except that we cannot signal to the peer we
19-
// want to use no context takeover on our side, we can only signal that they should.
20-
// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
15+
// Works in all browsers except Safari which does not implement the deflate extension.
2116
type CompressionMode int
2217

2318
const (
@@ -65,15 +60,15 @@ type compressionOptions struct {
6560
serverNoContextTakeover bool
6661
}
6762

68-
func (copts *compressionOptions) setHeader(h http.Header) {
63+
func (copts *compressionOptions) String() string {
6964
s := "permessage-deflate"
7065
if copts.clientNoContextTakeover {
7166
s += "; client_no_context_takeover"
7267
}
7368
if copts.serverNoContextTakeover {
7469
s += "; server_no_context_takeover"
7570
}
76-
h.Set("Sec-WebSocket-Extensions", s)
71+
return s
7772
}
7873

7974
// These bytes are required to get flate.Reader to return.

Diff for: dial.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
185185
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
186186
}
187187
if copts != nil {
188-
copts.setHeader(req.Header)
188+
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
189189
}
190190

191191
resp, err := opts.HTTPClient.Do(req)
@@ -273,6 +273,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
273273
copts.serverNoContextTakeover = true
274274
continue
275275
}
276+
if strings.HasPrefix(p, "server_max_window_bits=") {
277+
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
278+
continue
279+
}
276280

277281
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
278282
}

Diff for: internal/test/assert/assert.go

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package assert
22

33
import (
4+
"errors"
45
"fmt"
56
"reflect"
67
"strings"
@@ -43,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) {
4344
t.Fatalf("expected %q to contain %q", s, sub)
4445
}
4546
}
47+
48+
// ErrorIs asserts errors.Is(got, exp)
49+
func ErrorIs(t testing.TB, exp, got error) {
50+
t.Helper()
51+
52+
if !errors.Is(got, exp) {
53+
t.Fatalf("expected %v but got %v", exp, got)
54+
}
55+
}

Diff for: ws_js.go

+1-6
Original file line numberDiff line numberDiff line change
@@ -485,12 +485,7 @@ func CloseStatus(err error) StatusCode {
485485

486486
// CompressionMode represents the modes available to the deflate extension.
487487
// See https://tools.ietf.org/html/rfc7692
488-
//
489-
// A compatibility layer is implemented for the older deflate-frame extension used
490-
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
491-
// It will work the same in every way except that we cannot signal to the peer we
492-
// want to use no context takeover on our side, we can only signal that they should.
493-
// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
488+
// Works in all browsers except Safari which does not implement the deflate extension.
494489
type CompressionMode int
495490

496491
const (

0 commit comments

Comments
 (0)