Skip to content

Commit 10137fa

Browse files
authored
Merge pull request #360 from Emyrk/emyrk/Sec-WebSocket-Key
Reject invalid "Sec-WebSocket-Key" headers from clients
2 parents 64ce009 + 305eab9 commit 10137fa

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

Diff for: accept.go

+12-1
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
185185
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
186186
}
187187

188-
if r.Header.Get("Sec-WebSocket-Key") == "" {
188+
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
189+
if len(websocketSecKeys) == 0 {
189190
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
190191
}
191192

193+
if len(websocketSecKeys) > 1 {
194+
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
195+
}
196+
197+
// The RFC states to remove any leading or trailing whitespace.
198+
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
199+
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
200+
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
201+
}
202+
192203
return 0, nil
193204
}
194205

Diff for: accept_test.go

+58-9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"nhooyr.io/websocket/internal/test/assert"
16+
"nhooyr.io/websocket/internal/test/xrand"
1617
)
1718

1819
func TestAccept(t *testing.T) {
@@ -36,7 +37,7 @@ func TestAccept(t *testing.T) {
3637
r.Header.Set("Connection", "Upgrade")
3738
r.Header.Set("Upgrade", "websocket")
3839
r.Header.Set("Sec-WebSocket-Version", "13")
39-
r.Header.Set("Sec-WebSocket-Key", "meow123")
40+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
4041
r.Header.Set("Origin", "harhar.com")
4142

4243
_, err := Accept(w, r, nil)
@@ -52,7 +53,7 @@ func TestAccept(t *testing.T) {
5253
r.Header.Set("Connection", "Upgrade")
5354
r.Header.Set("Upgrade", "websocket")
5455
r.Header.Set("Sec-WebSocket-Version", "13")
55-
r.Header.Set("Sec-WebSocket-Key", "meow123")
56+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
5657
r.Header.Set("Origin", "https://harhar.com")
5758

5859
_, err := Accept(w, r, nil)
@@ -67,7 +68,7 @@ func TestAccept(t *testing.T) {
6768
r.Header.Set("Connection", "Upgrade")
6869
r.Header.Set("Upgrade", "websocket")
6970
r.Header.Set("Sec-WebSocket-Version", "13")
70-
r.Header.Set("Sec-WebSocket-Key", "meow123")
71+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
7172
r.Header.Set("Sec-WebSocket-Extensions", extensions)
7273
return r
7374
}
@@ -116,7 +117,7 @@ func TestAccept(t *testing.T) {
116117
r.Header.Set("Connection", "Upgrade")
117118
r.Header.Set("Upgrade", "websocket")
118119
r.Header.Set("Sec-WebSocket-Version", "13")
119-
r.Header.Set("Sec-WebSocket-Key", "meow123")
120+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
120121

121122
_, err := Accept(w, r, nil)
122123
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
@@ -136,7 +137,7 @@ func TestAccept(t *testing.T) {
136137
r.Header.Set("Connection", "Upgrade")
137138
r.Header.Set("Upgrade", "websocket")
138139
r.Header.Set("Sec-WebSocket-Version", "13")
139-
r.Header.Set("Sec-WebSocket-Key", "meow123")
140+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
140141

141142
_, err := Accept(w, r, nil)
142143
assert.Contains(t, err, `failed to hijack connection`)
@@ -183,21 +184,59 @@ func Test_verifyClientHandshake(t *testing.T) {
183184
},
184185
},
185186
{
186-
name: "badWebSocketKey",
187+
name: "missingWebSocketKey",
188+
h: map[string]string{
189+
"Connection": "Upgrade",
190+
"Upgrade": "websocket",
191+
"Sec-WebSocket-Version": "13",
192+
},
193+
},
194+
{
195+
name: "emptyWebSocketKey",
187196
h: map[string]string{
188197
"Connection": "Upgrade",
189198
"Upgrade": "websocket",
190199
"Sec-WebSocket-Version": "13",
191200
"Sec-WebSocket-Key": "",
192201
},
193202
},
203+
{
204+
name: "shortWebSocketKey",
205+
h: map[string]string{
206+
"Connection": "Upgrade",
207+
"Upgrade": "websocket",
208+
"Sec-WebSocket-Version": "13",
209+
"Sec-WebSocket-Key": xrand.Base64(15),
210+
},
211+
},
212+
{
213+
name: "invalidWebSocketKey",
214+
h: map[string]string{
215+
"Connection": "Upgrade",
216+
"Upgrade": "websocket",
217+
"Sec-WebSocket-Version": "13",
218+
"Sec-WebSocket-Key": "notbase64",
219+
},
220+
},
221+
{
222+
name: "extraWebSocketKey",
223+
h: map[string]string{
224+
"Connection": "Upgrade",
225+
"Upgrade": "websocket",
226+
"Sec-WebSocket-Version": "13",
227+
// Kinda cheeky, but http headers are case-insensitive.
228+
// If 2 sec keys are present, this is a failure condition.
229+
"Sec-WebSocket-Key": xrand.Base64(16),
230+
"sec-webSocket-key": xrand.Base64(16),
231+
},
232+
},
194233
{
195234
name: "badHTTPVersion",
196235
h: map[string]string{
197236
"Connection": "Upgrade",
198237
"Upgrade": "websocket",
199238
"Sec-WebSocket-Version": "13",
200-
"Sec-WebSocket-Key": "meow123",
239+
"Sec-WebSocket-Key": xrand.Base64(16),
201240
},
202241
http1: true,
203242
},
@@ -207,7 +246,17 @@ func Test_verifyClientHandshake(t *testing.T) {
207246
"Connection": "keep-alive, Upgrade",
208247
"Upgrade": "websocket",
209248
"Sec-WebSocket-Version": "13",
210-
"Sec-WebSocket-Key": "meow123",
249+
"Sec-WebSocket-Key": xrand.Base64(16),
250+
},
251+
success: true,
252+
},
253+
{
254+
name: "successSecKeyExtraSpace",
255+
h: map[string]string{
256+
"Connection": "keep-alive, Upgrade",
257+
"Upgrade": "websocket",
258+
"Sec-WebSocket-Version": "13",
259+
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
211260
},
212261
success: true,
213262
},
@@ -227,7 +276,7 @@ func Test_verifyClientHandshake(t *testing.T) {
227276
}
228277

229278
for k, v := range tc.h {
230-
r.Header.Set(k, v)
279+
r.Header.Add(k, v)
231280
}
232281

233282
_, err := verifyClientRequest(httptest.NewRecorder(), r)

Diff for: internal/test/xrand/xrand.go

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package xrand
22

33
import (
44
"crypto/rand"
5+
"encoding/base64"
56
"fmt"
67
"math/big"
78
"strings"
@@ -45,3 +46,8 @@ func Int(max int) int {
4546
}
4647
return int(x.Int64())
4748
}
49+
50+
// Base64 returns a randomly generated base64 string of length n.
51+
func Base64(n int) string {
52+
return base64.StdEncoding.EncodeToString(Bytes(n))
53+
}

0 commit comments

Comments
 (0)