Skip to content

Commit d522d62

Browse files
committed
Server selects first acceptable compression offer
Unacceptable offers are declined without rejecting the request.
1 parent cc2d7bd commit d522d62

File tree

4 files changed

+88
-118
lines changed

4 files changed

+88
-118
lines changed

Diff for: accept.go

+17-60
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
118118
w.Header().Set("Sec-WebSocket-Protocol", subproto)
119119
}
120120

121-
copts, err := acceptCompression(r, w, opts.CompressionMode)
122-
if err != nil {
123-
return nil, err
121+
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
122+
if ok {
123+
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
124124
}
125125

126126
w.WriteHeader(http.StatusSwitchingProtocols)
@@ -230,26 +230,23 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
230230
return ""
231231
}
232232

233-
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
233+
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
234234
if mode == CompressionDisabled {
235-
return nil, nil
235+
return nil, false
236236
}
237-
238-
for _, ext := range websocketExtensions(r.Header) {
237+
for _, ext := range extensions {
239238
switch ext.name {
240239
case "permessage-deflate":
241-
return acceptDeflate(w, ext, mode)
242-
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
243-
// case "x-webkit-deflate-frame":
244-
// return acceptWebkitDeflate(w, ext, mode)
240+
if copts, ok := acceptDeflate(ext, mode); ok {
241+
return copts, true
242+
}
245243
}
246244
}
247-
return nil, nil
245+
return nil, false
248246
}
249247

250-
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
248+
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
251249
copts := mode.opts()
252-
253250
for _, p := range ext.params {
254251
switch p {
255252
case "client_no_context_takeover":
@@ -258,57 +255,17 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
258255
case "server_no_context_takeover":
259256
copts.serverNoContextTakeover = true
260257
continue
261-
case "server_max_window_bits=15":
262-
continue
263-
}
264-
265-
if strings.HasPrefix(p, "client_max_window_bits") {
266-
// We cannot adjust the read sliding window so cannot make use of this.
258+
case "client_max_window_bits",
259+
"server_max_window_bits=15":
267260
continue
268261
}
269-
270-
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
271-
http.Error(w, err.Error(), http.StatusBadRequest)
272-
return nil, err
273-
}
274-
275-
copts.setHeader(w.Header())
276-
277-
return copts, nil
278-
}
279-
280-
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
281-
copts := mode.opts()
282-
// The peer must explicitly request it.
283-
copts.serverNoContextTakeover = false
284-
285-
for _, p := range ext.params {
286-
if p == "no_context_takeover" {
287-
copts.serverNoContextTakeover = true
262+
if strings.HasPrefix(p, "client_max_window_bits=") {
263+
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
288264
continue
289265
}
290-
291-
// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
292-
// of ignoring it as the draft spec is unclear. It says the server can ignore it
293-
// but the server has no way of signalling to the client it was ignored as the parameters
294-
// are set one way.
295-
// Thus us ignoring it would make the client think we understood it which would cause issues.
296-
// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
297-
//
298-
// Either way, we're only implementing this for webkit which never sends the max_window_bits
299-
// parameter so we don't need to worry about it.
300-
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
301-
http.Error(w, err.Error(), http.StatusBadRequest)
302-
return nil, err
266+
return nil, false
303267
}
304-
305-
s := "x-webkit-deflate-frame"
306-
if copts.clientNoContextTakeover {
307-
s += "; no_context_takeover"
308-
}
309-
w.Header().Set("Sec-WebSocket-Extensions", s)
310-
311-
return copts, nil
268+
return copts, true
312269
}
313270

314271
func headerContainsToken(h http.Header, key, token string) bool {

Diff for: accept_test.go

+68-54
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,47 @@ func TestAccept(t *testing.T) {
4545
t.Run("badCompression", func(t *testing.T) {
4646
t.Parallel()
4747

48-
w := mockHijacker{
49-
ResponseWriter: httptest.NewRecorder(),
48+
newRequest := func(extensions string) *http.Request {
49+
r := httptest.NewRequest("GET", "/", nil)
50+
r.Header.Set("Connection", "Upgrade")
51+
r.Header.Set("Upgrade", "websocket")
52+
r.Header.Set("Sec-WebSocket-Version", "13")
53+
r.Header.Set("Sec-WebSocket-Key", "meow123")
54+
r.Header.Set("Sec-WebSocket-Extensions", extensions)
55+
return r
56+
}
57+
newResponseWriter := func() http.ResponseWriter {
58+
return mockHijacker{
59+
ResponseWriter: httptest.NewRecorder(),
60+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
61+
return nil, nil, errors.New("hijack error")
62+
},
63+
}
5064
}
51-
r := httptest.NewRequest("GET", "/", nil)
52-
r.Header.Set("Connection", "Upgrade")
53-
r.Header.Set("Upgrade", "websocket")
54-
r.Header.Set("Sec-WebSocket-Version", "13")
55-
r.Header.Set("Sec-WebSocket-Key", "meow123")
56-
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
5765

58-
_, err := Accept(w, r, &AcceptOptions{
59-
CompressionMode: CompressionContextTakeover,
66+
t.Run("withoutFallback", func(t *testing.T) {
67+
t.Parallel()
68+
69+
w := newResponseWriter()
70+
r := newRequest("permessage-deflate; harharhar")
71+
_, _ = Accept(w, r, &AcceptOptions{
72+
CompressionMode: CompressionNoContextTakeover,
73+
})
74+
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
75+
})
76+
t.Run("withFallback", func(t *testing.T) {
77+
t.Parallel()
78+
79+
w := newResponseWriter()
80+
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
81+
_, _ = Accept(w, r, &AcceptOptions{
82+
CompressionMode: CompressionNoContextTakeover,
83+
})
84+
assert.Equal(t, "extension header",
85+
w.Header().Get("Sec-WebSocket-Extensions"),
86+
CompressionNoContextTakeover.opts().String(),
87+
)
6088
})
61-
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
6289
})
6390

6491
t.Run("requireHttpHijacker", func(t *testing.T) {
@@ -321,79 +348,66 @@ func Test_authenticateOrigin(t *testing.T) {
321348
}
322349
}
323350

324-
func Test_acceptCompression(t *testing.T) {
351+
func Test_selectDeflate(t *testing.T) {
325352
t.Parallel()
326353

327354
testCases := []struct {
328-
name string
329-
mode CompressionMode
330-
reqSecWebSocketExtensions string
331-
respSecWebSocketExtensions string
332-
expCopts *compressionOptions
333-
error bool
355+
name string
356+
mode CompressionMode
357+
header string
358+
expCopts *compressionOptions
359+
expOK bool
334360
}{
335361
{
336362
name: "disabled",
337363
mode: CompressionDisabled,
338364
expCopts: nil,
365+
expOK: false,
339366
},
340367
{
341368
name: "noClientSupport",
342369
mode: CompressionNoContextTakeover,
343370
expCopts: nil,
371+
expOK: false,
344372
},
345373
{
346-
name: "permessage-deflate",
347-
mode: CompressionNoContextTakeover,
348-
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
349-
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
374+
name: "permessage-deflate",
375+
mode: CompressionNoContextTakeover,
376+
header: "permessage-deflate; client_max_window_bits",
350377
expCopts: &compressionOptions{
351378
clientNoContextTakeover: true,
352379
serverNoContextTakeover: true,
353380
},
381+
expOK: true,
382+
},
383+
{
384+
name: "permessage-deflate/unknown-parameter",
385+
mode: CompressionNoContextTakeover,
386+
header: "permessage-deflate; meow",
387+
expOK: false,
354388
},
355389
{
356-
name: "permessage-deflate/error",
357-
mode: CompressionNoContextTakeover,
358-
reqSecWebSocketExtensions: "permessage-deflate; meow",
359-
error: true,
390+
name: "permessage-deflate/unknown-parameter",
391+
mode: CompressionNoContextTakeover,
392+
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
393+
expCopts: &compressionOptions{
394+
clientNoContextTakeover: true,
395+
serverNoContextTakeover: true,
396+
},
397+
expOK: true,
360398
},
361-
// {
362-
// name: "x-webkit-deflate-frame",
363-
// mode: CompressionNoContextTakeover,
364-
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
365-
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
366-
// expCopts: &compressionOptions{
367-
// clientNoContextTakeover: true,
368-
// serverNoContextTakeover: true,
369-
// },
370-
// },
371-
// {
372-
// name: "x-webkit-deflate/error",
373-
// mode: CompressionNoContextTakeover,
374-
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
375-
// error: true,
376-
// },
377399
}
378400

379401
for _, tc := range testCases {
380402
tc := tc
381403
t.Run(tc.name, func(t *testing.T) {
382404
t.Parallel()
383405

384-
r := httptest.NewRequest(http.MethodGet, "/", nil)
385-
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
386-
387-
w := httptest.NewRecorder()
388-
copts, err := acceptCompression(r, w, tc.mode)
389-
if tc.error {
390-
assert.Error(t, err)
391-
return
392-
}
393-
394-
assert.Success(t, err)
406+
h := http.Header{}
407+
h.Set("Sec-WebSocket-Extensions", tc.header)
408+
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
409+
assert.Equal(t, "selected options", tc.expOK, ok)
395410
assert.Equal(t, "compression options", tc.expCopts, copts)
396-
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
397411
})
398412
}
399413
}

Diff for: compress.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package websocket
55
import (
66
"compress/flate"
77
"io"
8-
"net/http"
98
"sync"
109
)
1110

@@ -58,15 +57,15 @@ type compressionOptions struct {
5857
serverNoContextTakeover bool
5958
}
6059

61-
func (copts *compressionOptions) setHeader(h http.Header) {
60+
func (copts *compressionOptions) String() string {
6261
s := "permessage-deflate"
6362
if copts.clientNoContextTakeover {
6463
s += "; client_no_context_takeover"
6564
}
6665
if copts.serverNoContextTakeover {
6766
s += "; server_no_context_takeover"
6867
}
69-
h.Set("Sec-WebSocket-Extensions", s)
68+
return s
7069
}
7170

7271
// These bytes are required to get flate.Reader to return.

Diff for: dial.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
162162
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
163163
}
164164
if copts != nil {
165-
copts.setHeader(req.Header)
165+
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
166166
}
167167

168168
resp, err := opts.HTTPClient.Do(req)

0 commit comments

Comments
 (0)