Skip to content

Commit b06417f

Browse files
committed
Extend DialOptions to allow Host header override
1 parent 8dee580 commit b06417f

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

Diff for: dial.go

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ type DialOptions struct {
3030
// HTTPHeader specifies the HTTP headers included in the handshake request.
3131
HTTPHeader http.Header
3232

33+
// Host optionally overrides the Host HTTP header to send. If empty, the value
34+
// of URL.Host will be used.
35+
Host string
36+
3337
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
3438
Subprotocols []string
3539

@@ -158,6 +162,9 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
158162
}
159163

160164
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
165+
if len(opts.Host) > 0 {
166+
req.Host = opts.Host
167+
}
161168
req.Header = opts.HTTPHeader.Clone()
162169
req.Header.Set("Connection", "Upgrade")
163170
req.Header.Set("Upgrade", "websocket")

Diff for: dial_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,56 @@ func TestBadDials(t *testing.T) {
108108
})
109109
}
110110

111+
func Test_verifyHostOverride(t *testing.T) {
112+
testCases := []struct {
113+
name string
114+
host string
115+
exp string
116+
}{
117+
{
118+
name: "noOverride",
119+
host: "",
120+
exp: "example.com",
121+
},
122+
{
123+
name: "hostOverride",
124+
host: "example.net",
125+
exp: "example.net",
126+
},
127+
}
128+
129+
for _, tc := range testCases {
130+
tc := tc
131+
t.Run(tc.name, func(t *testing.T) {
132+
t.Parallel()
133+
134+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
135+
defer cancel()
136+
137+
rt := func(r *http.Request) (*http.Response, error) {
138+
assert.Equal(t, "Host", tc.exp, r.Host)
139+
140+
h := http.Header{}
141+
h.Set("Connection", "Upgrade")
142+
h.Set("Upgrade", "websocket")
143+
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
144+
145+
return &http.Response{
146+
StatusCode: http.StatusSwitchingProtocols,
147+
Header: h,
148+
Body: ioutil.NopCloser(strings.NewReader("hi")),
149+
}, nil
150+
}
151+
152+
_, _, _ = Dial(ctx, "ws://example.com", &DialOptions{
153+
HTTPClient: mockHTTPClient(rt),
154+
Host: tc.host,
155+
})
156+
})
157+
}
158+
159+
}
160+
111161
func Test_verifyServerHandshake(t *testing.T) {
112162
t.Parallel()
113163

0 commit comments

Comments
 (0)