@@ -62,20 +62,50 @@ func TestAccept(t *testing.T) {
62
62
t .Run ("badCompression" , func (t * testing.T ) {
63
63
t .Parallel ()
64
64
65
- w := mockHijacker {
66
- ResponseWriter : httptest .NewRecorder (),
65
+ newRequest := func (extensions string ) * http.Request {
66
+ r := httptest .NewRequest ("GET" , "/" , nil )
67
+ r .Header .Set ("Connection" , "Upgrade" )
68
+ r .Header .Set ("Upgrade" , "websocket" )
69
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
70
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
71
+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
72
+ return r
73
+ }
74
+ errHijack := errors .New ("hijack error" )
75
+ newResponseWriter := func () http.ResponseWriter {
76
+ return mockHijacker {
77
+ ResponseWriter : httptest .NewRecorder (),
78
+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
79
+ return nil , nil , errHijack
80
+ },
81
+ }
67
82
}
68
- r := httptest .NewRequest ("GET" , "/" , nil )
69
- r .Header .Set ("Connection" , "Upgrade" )
70
- r .Header .Set ("Upgrade" , "websocket" )
71
- r .Header .Set ("Sec-WebSocket-Version" , "13" )
72
- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
73
- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
74
83
75
- _ , err := Accept (w , r , & AcceptOptions {
76
- CompressionMode : CompressionContextTakeover ,
84
+ t .Run ("withoutFallback" , func (t * testing.T ) {
85
+ t .Parallel ()
86
+
87
+ w := newResponseWriter ()
88
+ r := newRequest ("permessage-deflate; harharhar" )
89
+ _ , err := Accept (w , r , & AcceptOptions {
90
+ CompressionMode : CompressionNoContextTakeover ,
91
+ })
92
+ assert .ErrorIs (t , errHijack , err )
93
+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
94
+ })
95
+ t .Run ("withFallback" , func (t * testing.T ) {
96
+ t .Parallel ()
97
+
98
+ w := newResponseWriter ()
99
+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
100
+ _ , err := Accept (w , r , & AcceptOptions {
101
+ CompressionMode : CompressionNoContextTakeover ,
102
+ })
103
+ assert .ErrorIs (t , errHijack , err )
104
+ assert .Equal (t , "extension header" ,
105
+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
106
+ CompressionNoContextTakeover .opts ().String (),
107
+ )
77
108
})
78
- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
79
109
})
80
110
81
111
t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -344,79 +374,66 @@ func Test_authenticateOrigin(t *testing.T) {
344
374
}
345
375
}
346
376
347
- func Test_acceptCompression (t * testing.T ) {
377
+ func Test_selectDeflate (t * testing.T ) {
348
378
t .Parallel ()
349
379
350
380
testCases := []struct {
351
- name string
352
- mode CompressionMode
353
- reqSecWebSocketExtensions string
354
- respSecWebSocketExtensions string
355
- expCopts * compressionOptions
356
- error bool
381
+ name string
382
+ mode CompressionMode
383
+ header string
384
+ expCopts * compressionOptions
385
+ expOK bool
357
386
}{
358
387
{
359
388
name : "disabled" ,
360
389
mode : CompressionDisabled ,
361
390
expCopts : nil ,
391
+ expOK : false ,
362
392
},
363
393
{
364
394
name : "noClientSupport" ,
365
395
mode : CompressionNoContextTakeover ,
366
396
expCopts : nil ,
397
+ expOK : false ,
367
398
},
368
399
{
369
- name : "permessage-deflate" ,
370
- mode : CompressionNoContextTakeover ,
371
- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
372
- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
400
+ name : "permessage-deflate" ,
401
+ mode : CompressionNoContextTakeover ,
402
+ header : "permessage-deflate; client_max_window_bits" ,
373
403
expCopts : & compressionOptions {
374
404
clientNoContextTakeover : true ,
375
405
serverNoContextTakeover : true ,
376
406
},
407
+ expOK : true ,
408
+ },
409
+ {
410
+ name : "permessage-deflate/unknown-parameter" ,
411
+ mode : CompressionNoContextTakeover ,
412
+ header : "permessage-deflate; meow" ,
413
+ expOK : false ,
377
414
},
378
415
{
379
- name : "permessage-deflate/error" ,
380
- mode : CompressionNoContextTakeover ,
381
- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
382
- error : true ,
416
+ name : "permessage-deflate/unknown-parameter" ,
417
+ mode : CompressionNoContextTakeover ,
418
+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
419
+ expCopts : & compressionOptions {
420
+ clientNoContextTakeover : true ,
421
+ serverNoContextTakeover : true ,
422
+ },
423
+ expOK : true ,
383
424
},
384
- // {
385
- // name: "x-webkit-deflate-frame",
386
- // mode: CompressionNoContextTakeover,
387
- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
388
- // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
389
- // expCopts: &compressionOptions{
390
- // clientNoContextTakeover: true,
391
- // serverNoContextTakeover: true,
392
- // },
393
- // },
394
- // {
395
- // name: "x-webkit-deflate/error",
396
- // mode: CompressionNoContextTakeover,
397
- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
398
- // error: true,
399
- // },
400
425
}
401
426
402
427
for _ , tc := range testCases {
403
428
tc := tc
404
429
t .Run (tc .name , func (t * testing.T ) {
405
430
t .Parallel ()
406
431
407
- r := httptest .NewRequest (http .MethodGet , "/" , nil )
408
- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
409
-
410
- w := httptest .NewRecorder ()
411
- copts , err := acceptCompression (r , w , tc .mode )
412
- if tc .error {
413
- assert .Error (t , err )
414
- return
415
- }
416
-
417
- assert .Success (t , err )
432
+ h := http.Header {}
433
+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
434
+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
435
+ assert .Equal (t , "selected options" , tc .expOK , ok )
418
436
assert .Equal (t , "compression options" , tc .expCopts , copts )
419
- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
420
437
})
421
438
}
422
439
}
0 commit comments