From 729386c49596896faa6a1a6a2bbb865271e98db2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 13 Sep 2021 14:54:53 +0200 Subject: [PATCH] don't use a context to shut down the circuitv2 --- config/config.go | 14 ++++------- p2p/protocol/circuitv2/client/client.go | 28 ++++++++++++++++------ p2p/protocol/circuitv2/client/listen.go | 6 ++--- p2p/protocol/circuitv2/client/transport.go | 6 +++-- p2p/protocol/circuitv2/test/compat_test.go | 4 ++-- p2p/protocol/circuitv2/test/e2e_test.go | 23 +++++++++--------- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/config/config.go b/config/config.go index 1111ea33b9..92c8ceaf39 100644 --- a/config/config.go +++ b/config/config.go @@ -137,7 +137,7 @@ func (cfg *Config) makeSwarm(ctx context.Context) (*swarm.Swarm, error) { return swrm, nil } -func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) { +func (cfg *Config) addTransports(h host.Host) (err error) { swrm, ok := h.Network().(transport.TransportNetwork) if !ok { // Should probably skip this if no transports. @@ -165,15 +165,13 @@ func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) { return err } for _, t := range tpts { - err = swrm.AddTransport(t) - if err != nil { + if err := swrm.AddTransport(t); err != nil { return err } } if cfg.Relay { - err := circuitv2.AddTransport(ctx, h, upgrader) - if err != nil { + if err := circuitv2.AddTransport(h, upgrader); err != nil { h.Close() return err } @@ -225,8 +223,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { } } - err = cfg.addTransports(ctx, h) - if err != nil { + if err := cfg.addTransports(h); err != nil { h.Close() return nil, err } @@ -314,8 +311,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { return nil, err } dialerHost := blankhost.NewBlankHost(dialer) - err = autoNatCfg.addTransports(ctx, dialerHost) - if err != nil { + if err := autoNatCfg.addTransports(dialerHost); err != nil { dialerHost.Close() h.Close() return nil, err diff --git a/p2p/protocol/circuitv2/client/client.go b/p2p/protocol/circuitv2/client/client.go index e6ee5a4b94..c1b8aec05e 100644 --- a/p2p/protocol/circuitv2/client/client.go +++ b/p2p/protocol/circuitv2/client/client.go @@ -2,12 +2,14 @@ package client import ( "context" + "io" "sync" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" logging "github.com/ipfs/go-log" tptu "github.com/libp2p/go-libp2p-transport-upgrader" @@ -24,9 +26,10 @@ var log = logging.Logger("p2p-circuit") // This allows us to use the v2 code as drop in replacement for v1 in a host without breaking // existing code and interoperability with older nodes. type Client struct { - ctx context.Context - host host.Host - upgrader *tptu.Upgrader + ctx context.Context + ctxCancel context.CancelFunc + host host.Host + upgrader *tptu.Upgrader incoming chan accept @@ -35,6 +38,9 @@ type Client struct { hopCount map[peer.ID]int } +var _ io.Closer = &Client{} +var _ transport.Transport = &Client{} + type accept struct { conn *Conn writeResponse func() error @@ -48,15 +54,16 @@ type completion struct { // New constructs a new p2p-circuit/v2 client, attached to the given host and using the given // upgrader to perform connection upgrades. -func New(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) (*Client, error) { - return &Client{ - ctx: ctx, +func New(h host.Host, upgrader *tptu.Upgrader) (*Client, error) { + cl := &Client{ host: h, upgrader: upgrader, incoming: make(chan accept), activeDials: make(map[peer.ID]*completion), hopCount: make(map[peer.ID]int), - }, nil + } + cl.ctx, cl.ctxCancel = context.WithCancel(context.Background()) + return cl, nil } // Start registers the circuit (client) protocol stream handlers @@ -64,3 +71,10 @@ func (c *Client) Start() { c.host.SetStreamHandler(proto.ProtoIDv1, c.handleStreamV1) c.host.SetStreamHandler(proto.ProtoIDv2Stop, c.handleStreamV2) } + +func (c *Client) Close() error { + c.ctxCancel() + c.host.RemoveStreamHandler(proto.ProtoIDv1) + c.host.RemoveStreamHandler(proto.ProtoIDv2Stop) + return nil +} diff --git a/p2p/protocol/circuitv2/client/listen.go b/p2p/protocol/circuitv2/client/listen.go index fcfb9fa690..0d44ac726d 100644 --- a/p2p/protocol/circuitv2/client/listen.go +++ b/p2p/protocol/circuitv2/client/listen.go @@ -1,6 +1,7 @@ package client import ( + "errors" "net" ma "github.com/multiformats/go-multiaddr" @@ -32,7 +33,7 @@ func (l *Listener) Accept() (manet.Conn, error) { return evt.conn, nil case <-l.ctx.Done(): - return nil, l.ctx.Err() + return nil, errors.New("circuit v2 client closed") } } } @@ -49,6 +50,5 @@ func (l *Listener) Multiaddr() ma.Multiaddr { } func (l *Listener) Close() error { - // noop for now - return nil + return (*Client)(l).Close() } diff --git a/p2p/protocol/circuitv2/client/transport.go b/p2p/protocol/circuitv2/client/transport.go index 406188cf42..b40adb8fc1 100644 --- a/p2p/protocol/circuitv2/client/transport.go +++ b/p2p/protocol/circuitv2/client/transport.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "io" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" @@ -17,13 +18,13 @@ var circuitAddr = ma.Cast(circuitProtocol.VCode) // AddTransport constructs a new p2p-circuit/v2 client and adds it as a transport to the // host network -func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) error { +func AddTransport(h host.Host, upgrader *tptu.Upgrader) error { n, ok := h.Network().(transport.TransportNetwork) if !ok { return fmt.Errorf("%v is not a transport network", h.Network()) } - c, err := New(ctx, h, upgrader) + c, err := New(h, upgrader) if err != nil { return fmt.Errorf("error constructing circuit client: %w", err) } @@ -45,6 +46,7 @@ func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) err // Transport interface var _ transport.Transport = (*Client)(nil) +var _ io.Closer = (*Client)(nil) func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { conn, err := c.dial(ctx, a, p) diff --git a/p2p/protocol/circuitv2/test/compat_test.go b/p2p/protocol/circuitv2/test/compat_test.go index bcecba022f..750dfd341c 100644 --- a/p2p/protocol/circuitv2/test/compat_test.go +++ b/p2p/protocol/circuitv2/test/compat_test.go @@ -30,7 +30,7 @@ func TestRelayCompatV2DialV1(t *testing.T) { hosts, upgraders := getNetHosts(t, ctx, 3) addTransportV1(t, ctx, hosts[0], upgraders[0]) - addTransport(t, ctx, hosts[2], upgraders[2]) + addTransport(t, hosts[2], upgraders[2]) rch := make(chan []byte, 1) hosts[0].SetStreamHandler("test", func(s network.Stream) { @@ -105,7 +105,7 @@ func TestRelayCompatV1DialV2(t *testing.T) { defer cancel() hosts, upgraders := getNetHosts(t, ctx, 3) - addTransport(t, ctx, hosts[0], upgraders[0]) + addTransport(t, hosts[0], upgraders[0]) addTransportV1(t, ctx, hosts[2], upgraders[2]) rch := make(chan []byte, 1) diff --git a/p2p/protocol/circuitv2/test/e2e_test.go b/p2p/protocol/circuitv2/test/e2e_test.go index cf4c9bd2b9..905566b77a 100644 --- a/p2p/protocol/circuitv2/test/e2e_test.go +++ b/p2p/protocol/circuitv2/test/e2e_test.go @@ -20,12 +20,12 @@ import ( logging "github.com/ipfs/go-log" bhost "github.com/libp2p/go-libp2p-blankhost" - metrics "github.com/libp2p/go-libp2p-core/metrics" - pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem" + "github.com/libp2p/go-libp2p-core/metrics" + "github.com/libp2p/go-libp2p-peerstore/pstoremem" swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" tptu "github.com/libp2p/go-libp2p-transport-upgrader" - tcp "github.com/libp2p/go-tcp-transport" + "github.com/libp2p/go-tcp-transport" ma "github.com/multiformats/go-multiaddr" ) @@ -85,9 +85,8 @@ func connect(t *testing.T, a, b host.Host) { } } -func addTransport(t *testing.T, ctx context.Context, h host.Host, upgrader *tptu.Upgrader) { - err := client.AddTransport(ctx, h, upgrader) - if err != nil { +func addTransport(t *testing.T, h host.Host, upgrader *tptu.Upgrader) { + if err := client.AddTransport(h, upgrader); err != nil { t.Fatal(err) } } @@ -97,8 +96,8 @@ func TestBasicRelay(t *testing.T) { defer cancel() hosts, upgraders := getNetHosts(t, ctx, 3) - addTransport(t, ctx, hosts[0], upgraders[0]) - addTransport(t, ctx, hosts[2], upgraders[2]) + addTransport(t, hosts[0], upgraders[0]) + addTransport(t, hosts[2], upgraders[2]) rch := make(chan []byte, 1) hosts[0].SetStreamHandler("test", func(s network.Stream) { @@ -184,8 +183,8 @@ func TestRelayLimitTime(t *testing.T) { defer cancel() hosts, upgraders := getNetHosts(t, ctx, 3) - addTransport(t, ctx, hosts[0], upgraders[0]) - addTransport(t, ctx, hosts[2], upgraders[2]) + addTransport(t, hosts[0], upgraders[0]) + addTransport(t, hosts[2], upgraders[2]) rch := make(chan error, 1) hosts[0].SetStreamHandler("test", func(s network.Stream) { @@ -258,8 +257,8 @@ func TestRelayLimitData(t *testing.T) { defer cancel() hosts, upgraders := getNetHosts(t, ctx, 3) - addTransport(t, ctx, hosts[0], upgraders[0]) - addTransport(t, ctx, hosts[2], upgraders[2]) + addTransport(t, hosts[0], upgraders[0]) + addTransport(t, hosts[2], upgraders[2]) rch := make(chan int, 1) hosts[0].SetStreamHandler("test", func(s network.Stream) {