diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go index c0970d846..3dba9175d 100644 --- a/http2/h2c/h2c.go +++ b/http2/h2c/h2c.go @@ -249,6 +249,38 @@ func convertH1ReqToH2(r *http.Request) (*bytes.Buffer, []http2.Setting, error) { } } + // Any request body create as DATA frames + if r.Body != nil && r.Body != http.NoBody { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, nil, fmt.Errorf("Could not read request body: %v", err) + } + + needOneDataFrame := len(body) < maxFrameSize + err = framer.WriteData(1, + needOneDataFrame, // end stream? + body) + if err != nil { + return nil, nil, err + } + + for i := maxFrameSize; i < len(body); i += maxFrameSize { + if len(body)-i > maxFrameSize { + if err := framer.WriteData(1, + false, // end stream? + body[i:maxFrameSize]); err != nil { + return nil, nil, err + } + } else { + if err := framer.WriteData(1, + true, // end stream? + body[i:]); err != nil { + return nil, nil, err + } + } + } + } + return h2Bytes, settings, nil } diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go index bd9461ec6..201046e99 100644 --- a/http2/h2c/h2c_test.go +++ b/http2/h2c/h2c_test.go @@ -104,3 +104,79 @@ func TestContext(t *testing.T) { t.Fatal(err) } } + +func Test_convertH1ReqToH2_with_POST(t *testing.T) { + postBody := "Some POST Body" + + r, err := http.NewRequest("POST", "http://localhost:80", bytes.NewBufferString(postBody)) + if err != nil { + t.Fatal(err) + } + + r.Header.Set("Upgrade", "h2c") + r.Header.Set("Connection", "Upgrade, HTTP2-Settings") + r.Header.Set("HTTP2-Settings", "AAEAAEAAAAIAAAABAAMAAABkAAQBAAAAAAUAAEAA") // Some Default Settings + h2Bytes, _, err := convertH1ReqToH2(r) + + if err != nil { + t.Fatal(err) + } + + // Read off the preface + preface := []byte(http2.ClientPreface) + if h2Bytes.Len() < len(preface) { + t.Fatal("Could not read HTTP/2 ClientPreface") + } + readPreface := h2Bytes.Next(len(preface)) + if string(readPreface) != http2.ClientPreface { + t.Fatalf("Expected Preface %s but got: %s", http2.ClientPreface, string(readPreface)) + } + + framer := http2.NewFramer(nil, h2Bytes) + + // Should get a SETTINGS, HEADERS, and then DATA + expectedFrameTypes := []http2.FrameType{http2.FrameSettings, http2.FrameHeaders, http2.FrameData} + for frameNumber := 0; h2Bytes.Len() > 0; { + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal(err) + } + + if frameNumber >= len(expectedFrameTypes) { + t.Errorf("Got more than %d frames, wanted only %d", len(expectedFrameTypes), len(expectedFrameTypes)) + } + + if frame.Header().Type != expectedFrameTypes[frameNumber] { + t.Errorf("Got FrameType %v, wanted %v", frame.Header().Type, expectedFrameTypes[frameNumber]) + } + + frameNumber += 1 + + switch f := frame.(type) { + case *http2.SettingsFrame: + if frameNumber != 1 { + t.Errorf("Got SETTINGS frame as frame #%d, wanted it as frame #1", frameNumber) + } + case *http2.HeadersFrame: + if frameNumber != 2 { + t.Errorf("Got HEADERS frame as frame #%d, wanted it as frame #2", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Fatalf("Expected StreamId 1, got %v", f.FrameHeader.StreamID) + } + case *http2.DataFrame: + if frameNumber != 3 { + t.Errorf("Got DATA frame as frame #%d, wanted it as frame #3", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Errorf("Got StreamID %v, wanted 1", f.FrameHeader.StreamID) + } + + body := string(f.Data()) + + if body != postBody { + t.Errorf("Got DATA body %s, wanted %s", body, postBody) + } + } + } +}