From 6e2c99c943496e33025da68db088edff5dc7d07b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 18 Mar 2024 13:01:26 -0700 Subject: [PATCH] http2: allow testing Transports with testSyncHooks Change-Id: Icafc4860ef0691e5133221a0b53bb1d2158346cc Reviewed-on: https://go-review.googlesource.com/c/net/+/572378 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 202 ++++++++++++++++++++------ http2/transport.go | 35 +++-- http2/transport_test.go | 301 ++++++++++++++------------------------- 3 files changed, 288 insertions(+), 250 deletions(-) diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 97f884c66..73ceefd7b 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -99,62 +99,57 @@ type testClientConn struct { roundtrips []*testRoundTrip - rerr error // returned by Read - rbuf bytes.Buffer // sent to the test conn - wbuf bytes.Buffer // sent by the test conn + rerr error // returned by Read + netConnClosed bool // set when the ClientConn closes the net.Conn + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn } -func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { - t.Helper() - - tr := &Transport{} - for _, o := range opts { - o(tr) - } - +func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { tc := &testClientConn{ t: t, - tr: tr, - hooks: newTestSyncHooks(), + tr: cc.t, + cc: cc, + hooks: cc.t.syncHooks, } + cc.tconn = (*testClientConnNetConn)(tc) tc.enc = hpack.NewEncoder(&tc.encbuf) tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) tc.fr.SetMaxReadFrameSize(10 << 20) - t.Cleanup(func() { tc.sync() if tc.rerr == nil { tc.rerr = io.EOF } tc.sync() - if tc.hooks.total != 0 { - t.Errorf("%v goroutines still running after test completed", tc.hooks.total) - } - }) + return tc +} - tc.hooks.newclientconn = func(cc *ClientConn) { - tc.cc = cc - } - const singleUse = false - _, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks) - if err != nil { - t.Fatal(err) - } - tc.sync() - tc.hooks.newclientconn = nil - +func (tc *testClientConn) readClientPreface() { + tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. buf := make([]byte, len(clientPreface)) if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { - t.Fatalf("reading preface: %v", err) + tc.t.Fatalf("reading preface: %v", err) } if !bytes.Equal(buf, clientPreface) { - t.Fatalf("client preface: %q, want %q", buf, clientPreface) + tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) } +} - return tc +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + const singleUse = false + _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() } // sync waits for the ClientConn under test to reach a stable state, @@ -349,7 +344,7 @@ func (b *testRequestBody) closeWithError(err error) { // the request times out, or some other terminal condition is reached.) 
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { rt := &testRoundTrip{ - tc: tc, + t: tc.t, donec: make(chan struct{}), } tc.roundtrips = append(tc.roundtrips, rt) @@ -362,6 +357,9 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { tc.hooks.newstream = nil tc.t.Cleanup(func() { + if !rt.done() { + return + } res, _ := rt.result() if res != nil { res.Body.Close() @@ -460,6 +458,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he tc.sync() } +func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) { + tc.t.Helper() + if err := tc.fr.WriteRSTStream(streamID, code); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + func (tc *testClientConn) writePing(ack bool, data [8]byte) { tc.t.Helper() if err := tc.fr.WritePing(ack, data); err != nil { @@ -491,9 +497,25 @@ func (tc *testClientConn) closeWrite(err error) { tc.sync() } +// inflowWindow returns the amount of inbound flow control available for a stream, +// or for the connection if streamID is 0. +func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + tc.cc.mu.Lock() + defer tc.cc.mu.Unlock() + if streamID == 0 { + return tc.cc.inflow.avail + tc.cc.inflow.unsent + } + cs := tc.cc.streams[streamID] + if cs == nil { + tc.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent +} + // testRoundTrip manages a RoundTrip in progress. type testRoundTrip struct { - tc *testClientConn + t *testing.T resp *http.Response respErr error donec chan struct{} @@ -502,6 +524,9 @@ type testRoundTrip struct { // streamID returns the HTTP/2 stream ID of the request. func (rt *testRoundTrip) streamID() uint32 { + if rt.cs == nil { + panic("stream ID unknown") + } return rt.cs.ID } @@ -517,12 +542,12 @@ func (rt *testRoundTrip) done() bool { // result returns the result of the RoundTrip. func (rt *testRoundTrip) result() (*http.Response, error) { - t := rt.tc.t + t := rt.t t.Helper() select { case <-rt.donec: default: - t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID()) + t.Fatalf("RoundTrip is not done; want it to be") } return rt.resp, rt.respErr } @@ -530,7 +555,7 @@ func (rt *testRoundTrip) result() (*http.Response, error) { // response returns the response of a successful RoundTrip. // If the RoundTrip unexpectedly failed, it calls t.Fatal. func (rt *testRoundTrip) response() *http.Response { - t := rt.tc.t + t := rt.t t.Helper() resp, err := rt.result() if err != nil { @@ -544,7 +569,7 @@ func (rt *testRoundTrip) response() *http.Response { // err returns the (possibly nil) error result of RoundTrip. func (rt *testRoundTrip) err() error { - t := rt.tc.t + t := rt.t t.Helper() _, err := rt.result() return err @@ -552,7 +577,7 @@ func (rt *testRoundTrip) err() error { // wantStatus indicates the expected response StatusCode. func (rt *testRoundTrip) wantStatus(want int) { - t := rt.tc.t + t := rt.t t.Helper() if got := rt.response().StatusCode; got != want { t.Fatalf("got response status %v, want %v", got, want) @@ -561,7 +586,7 @@ func (rt *testRoundTrip) wantStatus(want int) { // body reads the contents of the response body. func (rt *testRoundTrip) readBody() ([]byte, error) { - t := rt.tc.t + t := rt.t t.Helper() return io.ReadAll(rt.response().Body) } @@ -569,7 +594,7 @@ func (rt *testRoundTrip) readBody() ([]byte, error) { // wantBody indicates the expected response body. // (Note that this consumes the body.) 
func (rt *testRoundTrip) wantBody(want []byte) { - t := rt.tc.t + t := rt.t t.Helper() got, err := rt.readBody() if err != nil { @@ -582,7 +607,7 @@ func (rt *testRoundTrip) wantBody(want []byte) { // wantHeaders indicates the expected response headers. func (rt *testRoundTrip) wantHeaders(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Header, want); diff != "" { @@ -592,7 +617,7 @@ func (rt *testRoundTrip) wantHeaders(want http.Header) { // wantTrailers indicates the expected response trailers. func (rt *testRoundTrip) wantTrailers(want http.Header) { - t := rt.tc.t + t := rt.t t.Helper() res := rt.response() if diff := diffHeaders(res.Trailer, want); diff != "" { @@ -630,7 +655,8 @@ func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { return nc.wbuf.Write(b) } -func (*testClientConnNetConn) Close() error { +func (nc *testClientConnNetConn) Close() error { + nc.netConnClosed = true return nil } @@ -639,3 +665,91 @@ func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } + +// A testTransport allows testing Transport.RoundTrip against fake servers. +// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling +// should use testClientConn instead. +type testTransport struct { + t *testing.T + tr *Transport + + ccs []*testClientConn +} + +func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { + tr := &Transport{ + syncHooks: newTestSyncHooks(), + } + for _, o := range opts { + o(tr) + } + + tt := &testTransport{ + t: t, + tr: tr, + } + tr.syncHooks.newclientconn = func(cc *ClientConn) { + tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc)) + } + + t.Cleanup(func() { + tt.sync() + if len(tt.ccs) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + } + if tt.tr.syncHooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total) + } + }) + + return tt +} + +func (tt *testTransport) sync() { + tt.tr.syncHooks.waitInactive() +} + +func (tt *testTransport) advance(d time.Duration) { + tt.tr.syncHooks.advance(d) + tt.sync() +} + +func (tt *testTransport) hasConn() bool { + return len(tt.ccs) > 0 +} + +func (tt *testTransport) getConn() *testClientConn { + tt.t.Helper() + if len(tt.ccs) == 0 { + tt.t.Fatalf("no new ClientConns created; wanted one") + } + tc := tt.ccs[0] + tt.ccs = tt.ccs[1:] + tc.sync() + tc.readClientPreface() + return tc +} + +func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tt.t, + donec: make(chan struct{}), + } + tt.tr.syncHooks.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + }) + tt.sync() + + tt.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} diff --git a/http2/transport.go b/http2/transport.go index 1ce5f125c..bf1dacd35 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -184,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -597,15 +599,6 @@ func authorityAddr(scheme string, authority 
string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -633,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -718,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -814,6 +824,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo } if hooks != nil { hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d diff --git a/http2/transport_test.go b/http2/transport_test.go index bab2472f3..5de0ad8c4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3688,61 +3688,49 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { } func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // The server refuses every attempt with REFUSED_STREAM, so RoundTrip retries until it gives up.
+ tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + var totalDelay time.Duration + count := 0 + for streamID := uint32(1); ; streamID += 2 { + count++ + tc.wantHeaders(wantHeader{ + streamID: streamID, + endStream: true, + }) + if streamID == 1 { + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) + tc.writeRSTStream(streamID, ErrCodeRefusedStream) + + d := tt.tr.syncHooks.timeUntilEvent() + if d == 0 { + if streamID == 1 { + continue } + break + } + totalDelay += d + if totalDelay > 5*time.Minute { + t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } + tt.advance(d) + } + if got, want := count, 5; got < want { + t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) + } + if rt.err() == nil { + t.Errorf("RoundTrip succeeded, want error") } - ct.run() } func TestTransportResponseDataBeforeHeaders(t *testing.T) { @@ -5593,155 +5581,80 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - lower ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool + // This test verifies that + // - receiving a protocol error on a connection does not interfere with + // other requests in flight on that connection; + // - the connection is not reused for further requests; and + // - the failed request is retried on a new connection. + tt := newTestTransport(t) + + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. + + // Request #1: The long request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tt.roundTrip(req1) + tc1 := tt.getConn() + tc1.wantFrameType(FrameSettings) + tc1.wantFrameType(FrameWindowUpdate) + tc1.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc1.writeSettings() + tc1.wantFrameType(FrameSettings) // settings ACK + + // Request #2(a): The short request.
+ req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt2 := tt.roundTrip(req2) + tc1.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } + // Request #2(a) fails with ErrCodeProtocol. + tc1.writeRSTStream(3, ErrCodeProtocol) + if rt1.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress") } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil + if rt2.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress") } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) + // Request #2(b): The short request is retried on a new connection. + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc2.writeSettings() + tc2.wantFrameType(FrameSettings) // settings ACK - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. 
- return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - } - ct.run() + // Request #2(b) succeeds. + tc2.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "201", + ), + }) + rt2.wantStatus(201) + + // Request #1 succeeds. + tc1.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) } func TestClientConnReservations(t *testing.T) {