diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go index c0970d8467..d2a4c12613 100644 --- a/http2/h2c/h2c.go +++ b/http2/h2c/h2c.go @@ -249,6 +249,33 @@ func convertH1ReqToH2(r *http.Request) (*bytes.Buffer, []http2.Setting, error) { } } + // Any request body create as DATA frames + if r.Body != nil && r.Body != http.NoBody { + buf := make([]byte, maxFrameSize) + for { + n, err := r.Body.Read(buf) + if err != nil && err != io.EOF { + return nil, nil, fmt.Errorf("Could not read request body: %v", err) + } + + if n < maxFrameSize || err == io.EOF { + err = framer.WriteData(1, + true, // end stream? + buf[:n]) + if err != nil { + return nil, nil, err + } + break + } + + if err = framer.WriteData(1, + false, // end stream? + buf); err != nil { + return nil, nil, err + } + } + } + return h2Bytes, settings, nil } diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go index bd9461ec6e..ee8b73a8bc 100644 --- a/http2/h2c/h2c_test.go +++ b/http2/h2c/h2c_test.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "fmt" "io/ioutil" "log" @@ -104,3 +105,175 @@ func TestContext(t *testing.T) { t.Fatal(err) } } + +func TestConvertH1ReqToH2WithPOST(t *testing.T) { + postBody := "Some POST Body" + + r, err := http.NewRequest("POST", "http://localhost:80", bytes.NewBufferString(postBody)) + if err != nil { + t.Fatal(err) + } + + r.Header.Set("Upgrade", "h2c") + r.Header.Set("Connection", "Upgrade, HTTP2-Settings") + r.Header.Set("HTTP2-Settings", "AAEAAEAAAAIAAAABAAMAAABkAAQBAAAAAAUAAEAA") // Some Default Settings + h2Bytes, _, err := convertH1ReqToH2(r) + + if err != nil { + t.Fatal(err) + } + + // Read off the preface + preface := []byte(http2.ClientPreface) + if h2Bytes.Len() < len(preface) { + t.Fatal("Could not read HTTP/2 ClientPreface") + } + readPreface := h2Bytes.Next(len(preface)) + if string(readPreface) != http2.ClientPreface { + t.Fatalf("Expected Preface %s but got: %s", http2.ClientPreface, string(readPreface)) + } + + framer := http2.NewFramer(nil, h2Bytes) + + // Should get a SETTINGS, HEADERS, and then DATA + expectedFrameTypes := []http2.FrameType{http2.FrameSettings, http2.FrameHeaders, http2.FrameData} + for frameNumber := 0; h2Bytes.Len() > 0; { + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal(err) + } + + if frameNumber >= len(expectedFrameTypes) { + t.Errorf("Got more than %d frames, wanted only %d", len(expectedFrameTypes), len(expectedFrameTypes)) + } + + if frame.Header().Type != expectedFrameTypes[frameNumber] { + t.Errorf("Got FrameType %v, wanted %v", frame.Header().Type, expectedFrameTypes[frameNumber]) + } + + frameNumber += 1 + + switch f := frame.(type) { + case *http2.SettingsFrame: + if frameNumber != 1 { + t.Errorf("Got SETTINGS frame as frame #%d, wanted it as frame #1", frameNumber) + } + case *http2.HeadersFrame: + if frameNumber != 2 { + t.Errorf("Got HEADERS frame as frame #%d, wanted it as frame #2", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Fatalf("Got StreamID %v, wanted StreamID 1", f.FrameHeader.StreamID) + } + case *http2.DataFrame: + if frameNumber != 3 { + t.Errorf("Got DATA frame as frame #%d, wanted it as frame #3", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Fatalf("Got StreamID %v, wanted StreamID 1", f.FrameHeader.StreamID) + } + + body := string(f.Data()) + + if body != postBody { + t.Errorf("Got DATA body %s, wanted %s", body, postBody) + } + } + } +} + +func TestConvertH1ReqToH2WithPOSTBodyMultipleOfFrameSize(t *testing.T) { + frameSize := 1024 + fillByte := byte(0x45) + postBody := bytes.Repeat([]byte{fillByte}, frameSize*2) + + r, err := http.NewRequest("POST", "http://localhost:80", bytes.NewBuffer(postBody)) + if err != nil { + t.Fatal(err) + } + + var settingsBuffer bytes.Buffer + settingsFramer := http2.NewFramer(&settingsBuffer, nil) + settingsFramer.WriteSettings(http2.Setting{http2.SettingMaxFrameSize, uint32(frameSize)}) + settingsEncoded := base64.URLEncoding.EncodeToString(settingsBuffer.Bytes()) + + r.Header.Set("Upgrade", "h2c") + r.Header.Set("Connection", "Upgrade, HTTP2-Settings") + r.Header.Set("HTTP2-Settings", settingsEncoded) + h2Bytes, _, err := convertH1ReqToH2(r) + + if err != nil { + t.Fatal(err) + } + + // Read off the preface + preface := []byte(http2.ClientPreface) + if h2Bytes.Len() < len(preface) { + t.Fatal("Could not read HTTP/2 ClientPreface") + } + readPreface := h2Bytes.Next(len(preface)) + if string(readPreface) != http2.ClientPreface { + t.Fatalf("Expected Preface %s but got: %s", http2.ClientPreface, string(readPreface)) + } + + framer := http2.NewFramer(nil, h2Bytes) + + // Should get a SETTINGS, HEADERS, and then DATA + expectedFrameTypes := []http2.FrameType{http2.FrameSettings, http2.FrameHeaders, http2.FrameData, http2.FrameData, http2.FrameData} + for frameNumber := 0; h2Bytes.Len() > 0; { + frame, err := framer.ReadFrame() + if err != nil { + t.Fatal(err) + } + + if frameNumber >= len(expectedFrameTypes) { + t.Errorf("Got more than %d frames, wanted only %d", len(expectedFrameTypes), len(expectedFrameTypes)) + } + + if frame.Header().Type != expectedFrameTypes[frameNumber] { + t.Errorf("Got FrameType %v, wanted %v", frame.Header().Type, expectedFrameTypes[frameNumber]) + } + + frameNumber += 1 + + switch f := frame.(type) { + case *http2.SettingsFrame: + if frameNumber != 1 { + t.Errorf("Got SETTINGS frame as frame #%d, wanted it as frame #1", frameNumber) + } + case *http2.HeadersFrame: + if frameNumber != 2 { + t.Errorf("Got HEADERS frame as frame #%d, wanted it as frame #2", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Fatalf("Got StreamID %v, wanted StreamID 1", f.FrameHeader.StreamID) + } + case *http2.DataFrame: + if frameNumber < 3 { + t.Errorf("Got DATA frame as frame #%d, wanted it as frame #3 or later", frameNumber) + } + if f.FrameHeader.StreamID != 1 { + t.Fatalf("Got StreamID %v, wanted StreamID 1", f.FrameHeader.StreamID) + } + + if frameNumber < len(expectedFrameTypes) && len(f.Data()) < frameSize { + t.Errorf("Expected data frame with length %d, got %d", frameSize, len(f.Data())) + } + + if frameNumber == len(expectedFrameTypes) && len(f.Data()) != 0 { + t.Errorf("Got non-empty DATA frame as last frame, expected it to be empty") + } + + if frameNumber == len(expectedFrameTypes) && (f.FrameHeader.Flags&http2.FlagHeadersEndStream) == 0 { + t.Errorf("Got last DATA frame with end stream not set, expected end stream set") + } + + for _, b := range f.Data() { + if b != fillByte { + t.Errorf("Got data byte %d, wanted %d", b, fillByte) + break + } + } + } + } +}