@@ -45,20 +45,47 @@ func TestAccept(t *testing.T) {
45
45
t .Run ("badCompression" , func (t * testing.T ) {
46
46
t .Parallel ()
47
47
48
- w := mockHijacker {
49
- ResponseWriter : httptest .NewRecorder (),
48
+ newRequest := func (extensions string ) * http.Request {
49
+ r := httptest .NewRequest ("GET" , "/" , nil )
50
+ r .Header .Set ("Connection" , "Upgrade" )
51
+ r .Header .Set ("Upgrade" , "websocket" )
52
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
53
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
54
+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
55
+ return r
56
+ }
57
+ newResponseWriter := func () http.ResponseWriter {
58
+ return mockHijacker {
59
+ ResponseWriter : httptest .NewRecorder (),
60
+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
61
+ return nil , nil , errors .New ("hijack error" )
62
+ },
63
+ }
50
64
}
51
- r := httptest .NewRequest ("GET" , "/" , nil )
52
- r .Header .Set ("Connection" , "Upgrade" )
53
- r .Header .Set ("Upgrade" , "websocket" )
54
- r .Header .Set ("Sec-WebSocket-Version" , "13" )
55
- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
56
- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
57
65
58
- _ , err := Accept (w , r , & AcceptOptions {
59
- CompressionMode : CompressionContextTakeover ,
66
+ t .Run ("withoutFallback" , func (t * testing.T ) {
67
+ t .Parallel ()
68
+
69
+ w := newResponseWriter ()
70
+ r := newRequest ("permessage-deflate; harharhar" )
71
+ _ , _ = Accept (w , r , & AcceptOptions {
72
+ CompressionMode : CompressionNoContextTakeover ,
73
+ })
74
+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
75
+ })
76
+ t .Run ("withFallback" , func (t * testing.T ) {
77
+ t .Parallel ()
78
+
79
+ w := newResponseWriter ()
80
+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
81
+ _ , _ = Accept (w , r , & AcceptOptions {
82
+ CompressionMode : CompressionNoContextTakeover ,
83
+ })
84
+ assert .Equal (t , "extension header" ,
85
+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
86
+ CompressionNoContextTakeover .opts ().String (),
87
+ )
60
88
})
61
- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
62
89
})
63
90
64
91
t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -321,79 +348,66 @@ func Test_authenticateOrigin(t *testing.T) {
321
348
}
322
349
}
323
350
324
- func Test_acceptCompression (t * testing.T ) {
351
+ func Test_selectDeflate (t * testing.T ) {
325
352
t .Parallel ()
326
353
327
354
testCases := []struct {
328
- name string
329
- mode CompressionMode
330
- reqSecWebSocketExtensions string
331
- respSecWebSocketExtensions string
332
- expCopts * compressionOptions
333
- error bool
355
+ name string
356
+ mode CompressionMode
357
+ header string
358
+ expCopts * compressionOptions
359
+ expOK bool
334
360
}{
335
361
{
336
362
name : "disabled" ,
337
363
mode : CompressionDisabled ,
338
364
expCopts : nil ,
365
+ expOK : false ,
339
366
},
340
367
{
341
368
name : "noClientSupport" ,
342
369
mode : CompressionNoContextTakeover ,
343
370
expCopts : nil ,
371
+ expOK : false ,
344
372
},
345
373
{
346
- name : "permessage-deflate" ,
347
- mode : CompressionNoContextTakeover ,
348
- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
349
- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
374
+ name : "permessage-deflate" ,
375
+ mode : CompressionNoContextTakeover ,
376
+ header : "permessage-deflate; client_max_window_bits" ,
350
377
expCopts : & compressionOptions {
351
378
clientNoContextTakeover : true ,
352
379
serverNoContextTakeover : true ,
353
380
},
381
+ expOK : true ,
382
+ },
383
+ {
384
+ name : "permessage-deflate/unknown-parameter" ,
385
+ mode : CompressionNoContextTakeover ,
386
+ header : "permessage-deflate; meow" ,
387
+ expOK : false ,
354
388
},
355
389
{
356
- name : "permessage-deflate/error" ,
357
- mode : CompressionNoContextTakeover ,
358
- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
359
- error : true ,
390
+ name : "permessage-deflate/unknown-parameter" ,
391
+ mode : CompressionNoContextTakeover ,
392
+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
393
+ expCopts : & compressionOptions {
394
+ clientNoContextTakeover : true ,
395
+ serverNoContextTakeover : true ,
396
+ },
397
+ expOK : true ,
360
398
},
361
- // {
362
- // name: "x-webkit-deflate-frame",
363
- // mode: CompressionNoContextTakeover,
364
- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
365
- // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
366
- // expCopts: &compressionOptions{
367
- // clientNoContextTakeover: true,
368
- // serverNoContextTakeover: true,
369
- // },
370
- // },
371
- // {
372
- // name: "x-webkit-deflate/error",
373
- // mode: CompressionNoContextTakeover,
374
- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
375
- // error: true,
376
- // },
377
399
}
378
400
379
401
for _ , tc := range testCases {
380
402
tc := tc
381
403
t .Run (tc .name , func (t * testing.T ) {
382
404
t .Parallel ()
383
405
384
- r := httptest .NewRequest (http .MethodGet , "/" , nil )
385
- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
386
-
387
- w := httptest .NewRecorder ()
388
- copts , err := acceptCompression (r , w , tc .mode )
389
- if tc .error {
390
- assert .Error (t , err )
391
- return
392
- }
393
-
394
- assert .Success (t , err )
406
+ h := http.Header {}
407
+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
408
+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
409
+ assert .Equal (t , "selected options" , tc .expOK , ok )
395
410
assert .Equal (t , "compression options" , tc .expCopts , copts )
396
- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
397
411
})
398
412
}
399
413
}
0 commit comments