@@ -9,10 +9,11 @@ import (
9
9
"errors"
10
10
"fmt"
11
11
"io"
12
+ "log"
12
13
"net/http"
13
14
"net/textproto"
14
15
"net/url"
15
- "strconv "
16
+ "path/filepath "
16
17
"strings"
17
18
18
19
"nhooyr.io/websocket/internal/errd"
@@ -25,18 +26,27 @@ type AcceptOptions struct {
25
26
// reject it, close the connection when c.Subprotocol() == "".
26
27
Subprotocols []string
27
28
28
- // InsecureSkipVerify disables Accept's origin verification behaviour. By default,
29
- // the connection will only be accepted if the request origin is equal to the request
30
- // host.
29
+ // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
31
30
//
32
- // This is only required if you want javascript served from a different domain
33
- // to access your WebSocket server.
31
+ // Deprecated: Use OriginPatterns with a match all pattern of * instead to control
32
+ // origin authorization yourself.
33
+ InsecureSkipVerify bool
34
+
35
+ // OriginPatterns lists the host patterns for authorized origins.
36
+ // The request host is always authorized.
37
+ // Use this to enable cross origin WebSockets.
38
+ //
39
+ // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
40
+ // In such a case, example.com is the origin and chat.example.com is the request host.
41
+ // One would set this field to []string{"example.com"} to authorize example.com to connect.
34
42
//
35
- // See https://stackoverflow.com/a/37837709/4283659
43
+ // Each pattern is matched case insensitively against the request origin host
44
+ // with filepath.Match.
45
+ // See https://golang.org/pkg/path/filepath/#Match
36
46
//
37
47
// Please ensure you understand the ramifications of enabling this.
38
48
// If used incorrectly your WebSocket server will be open to CSRF attacks.
39
- InsecureSkipVerify bool
49
+ OriginPatterns [] string
40
50
41
51
// CompressionMode controls the compression mode.
42
52
// Defaults to CompressionNoContextTakeover.
@@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
77
87
}
78
88
79
89
if ! opts .InsecureSkipVerify {
80
- err = authenticateOrigin (r )
90
+ err = authenticateOrigin (r , opts . OriginPatterns )
81
91
if err != nil {
92
+ if errors .Is (err , filepath .ErrBadPattern ) {
93
+ log .Printf ("websocket: %v" , err )
94
+ err = errors .New (http .StatusText (http .StatusForbidden ))
95
+ }
82
96
http .Error (w , err .Error (), http .StatusForbidden )
83
97
return nil , err
84
98
}
@@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
165
179
return 0 , nil
166
180
}
167
181
168
- func authenticateOrigin (r * http.Request ) error {
182
+ func authenticateOrigin (r * http.Request , originHosts [] string ) error {
169
183
origin := r .Header .Get ("Origin" )
170
- if origin != "" {
171
- u , err := url .Parse (origin )
184
+ if origin == "" {
185
+ return nil
186
+ }
187
+
188
+ u , err := url .Parse (origin )
189
+ if err != nil {
190
+ return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
191
+ }
192
+
193
+ if strings .EqualFold (r .Host , u .Host ) {
194
+ return nil
195
+ }
196
+
197
+ for _ , hostPattern := range originHosts {
198
+ matched , err := match (hostPattern , u .Host )
172
199
if err != nil {
173
- return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
200
+ return fmt .Errorf ("failed to parse filepath pattern %q: %w" , hostPattern , err )
174
201
}
175
- if ! strings . EqualFold ( u . Host , r . Host ) {
176
- return fmt . Errorf ( "request Origin %q is not authorized for Host %q" , origin , r . Host )
202
+ if matched {
203
+ return nil
177
204
}
178
205
}
179
- return nil
206
+ return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
207
+ }
208
+
209
+ func match (pattern , s string ) (bool , error ) {
210
+ return filepath .Match (strings .ToLower (pattern ), strings .ToLower (s ))
180
211
}
181
212
182
213
func selectSubprotocol (r * http.Request , subprotocols []string ) string {
@@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
235
266
return copts , nil
236
267
}
237
268
238
- // parseExtensionParameter parses the value in the extension parameter p.
239
- func parseExtensionParameter (p string ) (int , bool ) {
240
- ps := strings .Split (p , "=" )
241
- if len (ps ) == 1 {
242
- return 0 , false
243
- }
244
- i , e := strconv .Atoi (strings .Trim (ps [1 ], `"` ))
245
- return i , e == nil
246
- }
247
-
248
269
func acceptWebkitDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
249
270
copts := mode .opts ()
250
271
// The peer must explicitly request it.
0 commit comments