Skip to content

Commit bc7ce89

Browse files
ChannyClausChan Kang
and
Chan Kang
authored
Check for and report bad protocol in TLSClientConfig.NextProtos (#788)
* return an error when Dialer.TLSClientConfig.NextProtos contains a protocol that is not http/1.1 * include the likely cause of the error in the error message * check for nil-ness of Dialer.TLSClientConfig before attempting to run the check * addressing the review * move the NextProtos test into a separate file so that it can be run conditionally on go versions >= 1.14 * moving the new error check into existing http response error block to reduce the possibility of false positives * wrapping the error in %w * using %v instead of %w for compatibility with older versions of go * Revert "using %v instead of %w for compatibility with older versions of go" This reverts commit d34dd94. * move the unit test back into the existing test code since golang build constraint is no longer necessary Co-authored-by: Chan Kang <chankang@[email protected]>
1 parent 27d91a9 commit bc7ce89

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"crypto/tls"
1111
"errors"
12+
"fmt"
1213
"io"
1314
"io/ioutil"
1415
"net"
@@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
370371

371372
resp, err := http.ReadResponse(conn.br, req)
372373
if err != nil {
374+
if d.TLSClientConfig != nil {
375+
for _, proto := range d.TLSClientConfig.NextProtos {
376+
if proto != "http/1.1" {
377+
return nil, nil, fmt.Errorf(
378+
"websocket: protocol %q was given but is not supported;"+
379+
"sharing tls.Config with net/http Transport can cause this error: %w",
380+
proto, err,
381+
)
382+
}
383+
}
384+
}
373385
return nil, nil, err
374386
}
375387

client_server_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
10981098
}
10991099
}
11001100
}
1101+
func TestNextProtos(t *testing.T) {
1102+
ts := httptest.NewUnstartedServer(
1103+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
1104+
)
1105+
ts.EnableHTTP2 = true
1106+
ts.StartTLS()
1107+
defer ts.Close()
1108+
1109+
d := Dialer{
1110+
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
1111+
}
1112+
1113+
r, err := ts.Client().Get(ts.URL)
1114+
if err != nil {
1115+
t.Fatalf("Get: %v", err)
1116+
}
1117+
r.Body.Close()
1118+
1119+
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
1120+
// after the Client.Get call from net/http above.
1121+
var containsHTTP2 bool = false
1122+
for _, proto := range d.TLSClientConfig.NextProtos {
1123+
if proto == "h2" {
1124+
containsHTTP2 = true
1125+
}
1126+
}
1127+
if !containsHTTP2 {
1128+
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
1129+
}
1130+
1131+
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
1132+
if err == nil {
1133+
t.Fatalf("Dial succeeded, expect fail ")
1134+
}
1135+
}

0 commit comments

Comments
 (0)