diff --git a/go.mod b/go.mod index 00808fdc58..2e2edc1092 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 github.com/flynn/noise v1.1.0 github.com/google/gopacket v1.1.19 - github.com/gorilla/websocket v1.5.1 + github.com/gorilla/websocket v1.5.3 github.com/hashicorp/golang-lru/arc/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/ipfs/go-cid v0.4.1 @@ -55,7 +55,7 @@ require ( github.com/quic-go/webtransport-go v0.8.0 github.com/raulk/go-watchdog v1.3.0 github.com/stretchr/testify v1.9.0 - go.uber.org/fx v1.21.1 + go.uber.org/fx v1.22.1 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.4.0 golang.org/x/crypto v0.23.0 @@ -126,9 +126,3 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect ) - -// Remove this once fx releases the next version. -// We want to ship with a fix around SIGINT handling: -// https://github.com/uber-go/fx/pull/1198. -// Context: https://github.com/libp2p/go-libp2p/issues/2785 -replace go.uber.org/fx v1.21.1 => github.com/uber-go/fx v1.21.2-0.20240515133256-cb9cccf55845 diff --git a/go.sum b/go.sum index a2a327b99b..3c44afb25e 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/golang-lru/arc/v2 v2.0.7 h1:QxkVTxwColcduO+LP7eJO56r2hFiG8zEbfAAzRv52KQ= @@ -401,8 +401,6 @@ github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8 github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -github.com/uber-go/fx v1.21.2-0.20240515133256-cb9cccf55845 h1:1ZbnuG7aj1UxZnfsJmEpACmspZMkj5Fdvg7C1yWgQCE= -github.com/uber-go/fx v1.21.2-0.20240515133256-cb9cccf55845/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= @@ -417,6 +415,8 @@ go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.22.1 h1:nvvln7mwyT5s1q201YE29V/BFrGor6vMiDNpU/78Mys= +go.uber.org/fx v1.22.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= diff --git a/p2p/host/peerstore/pstoreds/protobook.go b/p2p/host/peerstore/pstoreds/protobook.go index 40fa7d951b..9ef7d1c9fa 100644 --- a/p2p/host/peerstore/pstoreds/protobook.go +++ b/p2p/host/peerstore/pstoreds/protobook.go @@ -48,7 +48,7 @@ func NewProtoBook(meta pstore.PeerMetadata, opts ...ProtoBookOption) (*dsProtoBo } return ret }(), - maxProtos: 1024, + maxProtos: 128, } for _, opt := range opts { diff --git a/p2p/host/peerstore/pstoremem/protobook.go b/p2p/host/peerstore/pstoremem/protobook.go index 51c4b0282a..b28ffe11be 100644 --- a/p2p/host/peerstore/pstoremem/protobook.go +++ b/p2p/host/peerstore/pstoremem/protobook.go @@ -26,9 +26,6 @@ type memoryProtoBook struct { segments protoSegments maxProtos int - - lk sync.RWMutex - interned map[protocol.ID]protocol.ID } var _ pstore.ProtoBook = (*memoryProtoBook)(nil) @@ -44,7 +41,6 @@ func WithMaxProtocols(num int) ProtoBookOption { func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) { pb := &memoryProtoBook{ - interned: make(map[protocol.ID]protocol.ID, 256), segments: func() (ret protoSegments) { for i := range ret { ret[i] = &protoSegment{ @@ -53,7 +49,7 @@ func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) { } return ret }(), - maxProtos: 1024, + maxProtos: 128, } for _, opt := range opts { @@ -64,30 +60,6 @@ func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) { return pb, nil } -func (pb *memoryProtoBook) internProtocol(proto protocol.ID) protocol.ID { - // check if it is interned with the read lock - pb.lk.RLock() - interned, ok := pb.interned[proto] - pb.lk.RUnlock() - - if ok { - return interned - } - - // intern with the write lock - pb.lk.Lock() - defer pb.lk.Unlock() - - // check again in case it got interned in between locks - interned, ok = pb.interned[proto] - if ok { - return interned - } - - pb.interned[proto] = proto - return proto -} - func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error { if len(protos) > pb.maxProtos { return errTooManyProtocols @@ -95,7 +67,7 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error newprotos := make(map[protocol.ID]struct{}, len(protos)) for _, proto := range protos { - newprotos[pb.internProtocol(proto)] = struct{}{} + newprotos[proto] = struct{}{} } s := pb.segments.get(p) @@ -121,7 +93,7 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error } for _, proto := range protos { - protomap[pb.internProtocol(proto)] = struct{}{} + protomap[proto] = struct{}{} } return nil } @@ -151,7 +123,10 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) err } for _, proto := range protos { - delete(protomap, pb.internProtocol(proto)) + delete(protomap, proto) + } + if len(protomap) == 0 { + delete(s.protocols, p) } return nil } diff --git a/p2p/host/pstoremanager/pstoremanager.go b/p2p/host/pstoremanager/pstoremanager.go index 93cc2a98d9..f4a20f8ac4 100644 --- a/p2p/host/pstoremanager/pstoremanager.go +++ b/p2p/host/pstoremanager/pstoremanager.go @@ -121,10 +121,12 @@ func (m *PeerstoreManager) background(ctx context.Context, sub event.Subscriptio // Check that the peer is actually not connected at this point. // This avoids a race condition where the Connected notification // is processed after this time has fired. - if m.network.Connectedness(p) != network.Connected { + switch m.network.Connectedness(p) { + case network.Connected, network.Limited: + default: m.pstore.RemovePeer(p) - delete(disconnected, p) } + delete(disconnected, p) } } case <-ctx.Done(): diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index a65d64f24e..904e47cece 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "slices" "sync" "testing" @@ -730,6 +731,15 @@ func TestLargeIdentifyMessage(t *testing.T) { } } +func randString(n int) string { + chars := "abcdefghijklmnopqrstuvwxyz" + buf := make([]byte, n) + for i := 0; i < n; i++ { + buf[i] = chars[rand.Intn(len(chars))] + } + return string(buf) +} + func TestLargePushMessage(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -738,9 +748,9 @@ func TestLargePushMessage(t *testing.T) { h2 := blhost.NewBlankHost(swarmt.GenSwarm(t)) // add protocol strings to make the message larger - // about 2K of protocol strings - for i := 0; i < 500; i++ { - r := protocol.ID(fmt.Sprintf("rand%d", i)) + // about 3K of protocol strings + for i := 0; i < 100; i++ { + r := protocol.ID(fmt.Sprintf("%s-%d", randString(30), i)) h1.SetStreamHandler(r, func(network.Stream) {}) h2.SetStreamHandler(r, func(network.Stream) {}) } diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index e39b72a71a..e75118ccf4 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -20,7 +20,9 @@ import ( "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" + mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/sec" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" @@ -29,8 +31,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "go.uber.org/mock/gomock" - "github.com/multiformats/go-multiaddr" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -669,7 +672,7 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { ai := &peer.AddrInfo{ ID: bogusPeerId, - Addrs: []multiaddr.Multiaddr{h1.Addrs()[0]}, + Addrs: []ma.Multiaddr{h1.Addrs()[0]}, } // Try connecting with the bogus peer ID @@ -688,3 +691,34 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { }) } } + +// TestCloseConnWhenBlocked tests that the server closes the connection when the rcmgr blocks it. +func TestCloseConnWhenBlocked(t *testing.T) { + for _, tc := range transportsToTest { + if tc.Name == "WebRTC" { + continue // WebRTC doesn't have a connection when we block so there's nothing to close + } + t.Run(tc.Name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockRcmgr := mocknetwork.NewMockResourceManager(ctrl) + mockRcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).DoAndReturn(func(network.Direction, bool, ma.Multiaddr) (network.ConnManagementScope, error) { + // Block the connection + return nil, fmt.Errorf("connections blocked") + }) + mockRcmgr.EXPECT().Close().AnyTimes() + + server := tc.HostGenerator(t, TransportTestCaseOpts{ResourceManager: mockRcmgr}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := client.NewStream(ctx, server.ID(), ping.ID) + require.Error(t, err) + require.False(t, errors.Is(err, context.DeadlineExceeded), "expected error to be not be context deadline exceeded") + }) + } +} diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index d49b686497..0c69741358 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -51,8 +51,10 @@ func (l *listener) Accept() (tpt.CapableConn, error) { if err != nil { return nil, err } - c, err := l.setupConn(qconn) + c, err := l.wrapConn(qconn) if err != nil { + log.Debugf("failed to setup connection: %s", err) + qconn.CloseWithError(1, "") continue } l.transport.addConn(qconn, c) @@ -79,7 +81,10 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { +// wrapConn wraps a QUIC connection into a libp2p [tpt.CapableConn]. +// If wrapping fails. The caller is responsible for cleaning up the +// connection. +func (l *listener) wrapConn(qconn quic.Connection) (*conn, error) { remoteMultiaddr, err := quicreuse.ToQuicMultiaddr(qconn.RemoteAddr(), qconn.ConnectionState().Version) if err != nil { return nil, err @@ -90,18 +95,16 @@ func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err) return nil, err } - c, err := l.setupConnWithScope(qconn, connScope, remoteMultiaddr) + c, err := l.wrapConnWithScope(qconn, connScope, remoteMultiaddr) if err != nil { connScope.Done() - qconn.CloseWithError(1, "") return nil, err } return c, nil } -func (l *listener) setupConnWithScope(qconn quic.Connection, connScope network.ConnManagementScope, remoteMultiaddr ma.Multiaddr) (*conn, error) { - +func (l *listener) wrapConnWithScope(qconn quic.Connection, connScope network.ConnManagementScope, remoteMultiaddr ma.Multiaddr) (*conn, error) { // The tls.Config used to establish this connection already verified the certificate chain. // Since we don't have any way of knowing which tls.Config was used though, // we have to re-determine the peer's identity here. diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index d739c82c44..dbd6d810e4 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -1,9 +1,11 @@ package libp2pquic import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" "fmt" "io" "net" @@ -12,8 +14,11 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" + mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/quic-go/quic-go" + "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -113,3 +118,51 @@ func TestCorrectNumberOfVirtualListeners(t *testing.T) { ln.Close() require.Empty(t, tpt.listeners[udpAddr.String()]) } + +func TestCleanupConnWhenBlocked(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockRcmgr := mocknetwork.NewMockResourceManager(ctrl) + mockRcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(network.Direction, bool, ma.Multiaddr) (network.ConnManagementScope, error) { + // Block the connection + return nil, fmt.Errorf("connections blocked") + }) + + server := newTransport(t, mockRcmgr) + serverTpt := server.(*transport) + defer server.(io.Closer).Close() + + localAddrV1 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") + ln, err := server.Listen(localAddrV1) + require.NoError(t, err) + defer ln.Close() + go ln.Accept() + + client := newTransport(t, nil) + ctx := context.Background() + + var quicErr *quic.ApplicationError = &quic.ApplicationError{} + conn, err := client.Dial(ctx, ln.Multiaddr(), serverTpt.localPeer) + if err != nil && errors.As(err, &quicErr) { + // We hit our expected application error + return + } + + // No error yet, let's continue using the conn + s, err := conn.OpenStream(ctx) + if err != nil && errors.As(err, &quicErr) { + // We hit our expected application error + return + } + + // No error yet, let's continue using the conn + s.SetReadDeadline(time.Now().Add(10 * time.Second)) + b := [1]byte{} + _, err = s.Read(b[:]) + if err != nil && errors.As(err, &quicErr) { + // We hit our expected application error + return + } + + t.Fatalf("expected application error, got %v", err) +} diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 0e83b1d16f..0525124711 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -7,6 +7,7 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + "github.com/quic-go/quic-go" "github.com/quic-go/webtransport-go" ) @@ -31,16 +32,18 @@ type conn struct { session *webtransport.Session scope network.ConnManagementScope + qconn quic.Connection } var _ tpt.CapableConn = &conn{} -func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope) *conn { +func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn quic.Connection) *conn { return &conn{ connSecurityMultiaddrs: sconn, transport: tr, session: sess, scope: scope, + qconn: qconn, } } @@ -70,7 +73,9 @@ func (c *conn) allowWindowIncrease(size uint64) bool { func (c *conn) Close() error { c.scope.Done() c.transport.removeConn(c.session) - return c.session.CloseWithError(0, "") + err := c.session.CloseWithError(0, "") + _ = c.qconn.CloseWithError(1, "") + return err } func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 2a7c3546f2..ff611fe927 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -15,12 +15,61 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/quic-go/webtransport-go" ) const queueLen = 16 const handshakeTimeout = 10 * time.Second +type connKey struct{} + +// negotiatingConn is a wrapper around a quic.Connection that lets us wrap it in +// our own context for the duration of the upgrade process. Upgrading a quic +// connection to an h3 connection to a webtransport session. +type negotiatingConn struct { + quic.Connection + ctx context.Context + cancel context.CancelFunc + // stopClose is a function that stops the connection from being closed when + // the context is done. Returns true if the connection close function was + // not called. + stopClose func() bool + err error +} + +func (c *negotiatingConn) Unwrap() (quic.Connection, error) { + defer c.cancel() + if c.stopClose != nil { + // unwrap the first time + if !c.stopClose() { + c.err = errTimeout + } + c.stopClose = nil + } + if c.err != nil { + return nil, c.err + } + return c.Connection, nil +} + +func wrapConn(ctx context.Context, c quic.Connection, handshakeTimeout time.Duration) *negotiatingConn { + ctx, cancel := context.WithTimeout(ctx, handshakeTimeout) + stopClose := context.AfterFunc(ctx, func() { + log.Debugf("failed to handshake on conn: %s", c.RemoteAddr()) + c.CloseWithError(1, "") + }) + return &negotiatingConn{ + Connection: c, + ctx: ctx, + cancel: cancel, + stopClose: stopClose, + } +} + +var errTimeout = errors.New("timeout") + type listener struct { transport *transport isStaticTLSConf bool @@ -56,6 +105,11 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf addr: reuseListener.Addr(), multiaddr: localMultiaddr, server: webtransport.Server{ + H3: http3.Server{ + ConnContext: func(ctx context.Context, c quic.Connection) context.Context { + return context.WithValue(ctx, connKey{}, c) + }, + }, CheckOrigin: func(r *http.Request) bool { return true }, }, } @@ -71,7 +125,8 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf log.Debugw("serving failed", "addr", ln.Addr(), "error", err) return } - go ln.server.ServeQUICConn(conn) + wrapped := wrapConn(ln.ctx, conn, t.handshakeTimeout) + go ln.server.ServeQUICConn(wrapped) } }() return ln, nil @@ -137,13 +192,32 @@ func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Reque return err } - conn := newConn(l.transport, sess, sconn, connScope) + connVal := r.Context().Value(connKey{}) + if connVal == nil { + log.Errorf("missing conn from context") + sess.CloseWithError(1, "") + return errors.New("invalid context") + } + nconn, ok := connVal.(*negotiatingConn) + if !ok { + log.Errorf("unexpected connection in context. invalid conn type: %T", nconn) + sess.CloseWithError(1, "") + return errors.New("invalid context") + } + qconn, err := nconn.Unwrap() + if err != nil { + log.Debugf("handshake timed out: %s", r.RemoteAddr) + sess.CloseWithError(1, "") + return err + } + + conn := newConn(l.transport, sess, sconn, connScope, qconn) l.transport.addConn(sess, conn) select { case l.queue <- conn: default: log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) - sess.CloseWithError(1, "") + conn.Close() return errors.New("accept queue full") } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 97172703f7..ef8551d60f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -60,6 +60,13 @@ func WithTLSClientConfig(c *tls.Config) Option { } } +func WithHandshakeTimeout(d time.Duration) Option { + return func(t *transport) error { + t.handshakeTimeout = d + return nil + } +} + type transport struct { privKey ic.PrivKey pid peer.ID @@ -78,8 +85,9 @@ type transport struct { noise *noise.Transport - connMx sync.Mutex - conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key + connMx sync.Mutex + conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key + handshakeTimeout time.Duration } var _ tpt.Transport = &transport{} @@ -99,13 +107,14 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - connManager: connManager, - conns: map[quic.ConnectionTracingID]*conn{}, + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[quic.ConnectionTracingID]*conn{}, + handshakeTimeout: handshakeTimeout, } for _, opt := range opts { if err := opt(t); err != nil { @@ -159,7 +168,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee } maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT }) - sess, err := t.dial(ctx, maddr, url, sni, certHashes) + sess, qconn, err := t.dial(ctx, maddr, url, sni, certHashes) if err != nil { return nil, err } @@ -172,12 +181,12 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee sess.CloseWithError(errorCodeConnectionGating, "") return nil, fmt.Errorf("secured connection gated") } - conn := newConn(t, sess, sconn, scope) + conn := newConn(t, sess, sconn, scope, qconn) t.addConn(sess, conn) return conn, nil } -func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { +func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) { var tlsConf *tls.Config if t.tlsClientConf != nil { tlsConf = t.tlsClientConf.Clone() @@ -200,7 +209,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string } conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease) if err != nil { - return nil, err + return nil, nil, err } dialer := webtransport.Dialer{ DialAddr: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { @@ -210,12 +219,14 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string } rsp, sess, err := dialer.Dial(ctx, url, nil) if err != nil { - return nil, err + conn.CloseWithError(1, "") + return nil, nil, err } if rsp.StatusCode < 200 || rsp.StatusCode > 299 { - return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) + conn.CloseWithError(1, "") + return nil, nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } - return sess, err + return sess, conn, err } func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index f6c850a2b9..bd41446218 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net" + "net/http" "os" "runtime" "sync/atomic" @@ -827,3 +828,37 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { require.True(t, found, "Failed after hour: %v", i) } } + +func TestH3ConnClosed(t *testing.T) { + _, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second)) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) + require.NoError(t, err) + defer ln.Close() + + p, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + conn, err := quic.Dial(context.Background(), p, ln.Addr(), &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{http3.NextProtoH3}, + }, nil) + require.NoError(t, err) + rt := &http3.SingleDestinationRoundTripper{ + Connection: conn, + } + rt.Start() + require.Eventually(t, func() bool { + c := http.Client{ + Transport: rt, + Timeout: 1 * time.Second, + } + resp, err := c.Get(fmt.Sprintf("https://%s", ln.Addr().String())) + if err != nil { + return true + } + resp.Body.Close() + return false + }, 10*time.Second, 1*time.Second) +} diff --git a/test-plans/go.mod b/test-plans/go.mod index 6e4ff51d3d..e19fd63899 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -28,7 +28,7 @@ require ( github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20240207164012-fb44976bdcd5 // indirect github.com/google/uuid v1.4.0 // indirect - github.com/gorilla/websocket v1.5.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/huin/goupnp v1.3.0 // indirect github.com/ipfs/go-cid v0.4.1 // indirect github.com/ipfs/go-log/v2 v2.5.1 // indirect @@ -93,7 +93,7 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/stretchr/testify v1.9.0 // indirect go.uber.org/dig v1.17.1 // indirect - go.uber.org/fx v1.21.1 // indirect + go.uber.org/fx v1.22.1 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index fdd19bbb02..5e9755f1f0 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -101,8 +101,8 @@ github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= @@ -350,8 +350,8 @@ go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= -go.uber.org/fx v1.21.1 h1:RqBh3cYdzZS0uqwVeEjOX2p73dddLpym315myy/Bpb0= -go.uber.org/fx v1.21.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= +go.uber.org/fx v1.22.1 h1:nvvln7mwyT5s1q201YE29V/BFrGor6vMiDNpU/78Mys= +go.uber.org/fx v1.22.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= diff --git a/version.json b/version.json index f765664902..808595258c 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.35.1" + "version": "v0.35.2" }