diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go index c3df711d9..2b77ffdaf 100644 --- a/http2/h2c/h2c.go +++ b/http2/h2c/h2c.go @@ -70,6 +70,15 @@ func NewHandler(h http.Handler, s *http2.Server) http.Handler { } } +// extractServer extracts existing http.Server instance from http.Request or create an empty http.Server +func extractServer(r *http.Request) *http.Server { + server, ok := r.Context().Value(http.ServerContextKey).(*http.Server) + if ok { + return server + } + return new(http.Server) +} + // ServeHTTP implement the h2c support that is enabled by h2c.GetH2CHandler. func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handle h2c with prior knowledge (RFC 7540 Section 3.4) @@ -87,6 +96,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer conn.Close() s.s.ServeConn(conn, &http2.ServeConnOpts{ Context: r.Context(), + BaseConfig: extractServer(r), Handler: s.Handler, SawClientPreface: true, }) @@ -104,6 +114,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer conn.Close() s.s.ServeConn(conn, &http2.ServeConnOpts{ Context: r.Context(), + BaseConfig: extractServer(r), Handler: s.Handler, UpgradeRequest: r, Settings: settings, diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go index 3e5a2eb42..558e597c6 100644 --- a/http2/h2c/h2c_test.go +++ b/http2/h2c/h2c_test.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/http/httptest" + "strings" "testing" "golang.org/x/net/http2" @@ -74,3 +75,62 @@ func TestContext(t *testing.T) { t.Fatal(err) } } + +func TestPropagation(t *testing.T) { + var ( + server *http.Server + // double the limit because http2 will compress header + headerSize = 1 << 11 + headerLimit = 1 << 10 + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 2 { + t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor) + } + if r.Context().Value(http.ServerContextKey).(*http.Server) != server { + t.Errorf("Request doesn't have expected http server: %v", r.Context()) + } + if len(r.Header.Get("Long-Header")) != headerSize { + t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header"))) + } + fmt.Fprint(w, "Hello world") + }) + + h2s := &http2.Server{} + h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s)) + + server = h1s.Config + server.MaxHeaderBytes = headerLimit + server.ConnState = func(conn net.Conn, state http.ConnState) { + t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state) + } + + h1s.Start() + defer h1s.Close() + + client := &http.Client{ + Transport: &http2.Transport{ + AllowHTTP: true, + DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if conn != nil { + t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr()) + } + return conn, err + }, + }, + } + + req, err := http.NewRequest("GET", h1s.URL, nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("Long-Header", strings.Repeat("A", headerSize)) + + _, err = client.Do(req) + if err == nil { + t.Fatal("expected server err, got nil") + } +}