Skip to content

Commit

Permalink
Merge pull request #1185 from libp2p/circuit-shutdown
Browse files Browse the repository at this point in the history
don't use a context to shut down the circuitv2
  • Loading branch information
marten-seemann committed Sep 17, 2021
2 parents 7cb03db + 729386c commit a5f982f
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 35 deletions.
14 changes: 5 additions & 9 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (cfg *Config) makeSwarm(ctx context.Context) (*swarm.Swarm, error) {
return swrm, nil
}

func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
func (cfg *Config) addTransports(h host.Host) (err error) {
swrm, ok := h.Network().(transport.TransportNetwork)
if !ok {
// Should probably skip this if no transports.
Expand Down Expand Up @@ -165,15 +165,13 @@ func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
return err
}
for _, t := range tpts {
err = swrm.AddTransport(t)
if err != nil {
if err := swrm.AddTransport(t); err != nil {
return err
}
}

if cfg.Relay {
err := circuitv2.AddTransport(ctx, h, upgrader)
if err != nil {
if err := circuitv2.AddTransport(h, upgrader); err != nil {
h.Close()
return err
}
Expand Down Expand Up @@ -225,8 +223,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
}
}

err = cfg.addTransports(ctx, h)
if err != nil {
if err := cfg.addTransports(h); err != nil {
h.Close()
return nil, err
}
Expand Down Expand Up @@ -314,8 +311,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
return nil, err
}
dialerHost := blankhost.NewBlankHost(dialer)
err = autoNatCfg.addTransports(ctx, dialerHost)
if err != nil {
if err := autoNatCfg.addTransports(dialerHost); err != nil {
dialerHost.Close()
h.Close()
return nil, err
Expand Down
28 changes: 21 additions & 7 deletions p2p/protocol/circuitv2/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package client

import (
"context"
"io"
"sync"

"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport"

logging "github.com/ipfs/go-log/v2"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
Expand All @@ -24,9 +26,10 @@ var log = logging.Logger("p2p-circuit")
// This allows us to use the v2 code as drop in replacement for v1 in a host without breaking
// existing code and interoperability with older nodes.
type Client struct {
ctx context.Context
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
host host.Host
upgrader *tptu.Upgrader

incoming chan accept

Expand All @@ -35,6 +38,9 @@ type Client struct {
hopCount map[peer.ID]int
}

var _ io.Closer = &Client{}
var _ transport.Transport = &Client{}

type accept struct {
conn *Conn
writeResponse func() error
Expand All @@ -48,19 +54,27 @@ type completion struct {

// New constructs a new p2p-circuit/v2 client, attached to the given host and using the given
// upgrader to perform connection upgrades.
func New(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
return &Client{
ctx: ctx,
func New(h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
cl := &Client{
host: h,
upgrader: upgrader,
incoming: make(chan accept),
activeDials: make(map[peer.ID]*completion),
hopCount: make(map[peer.ID]int),
}, nil
}
cl.ctx, cl.ctxCancel = context.WithCancel(context.Background())
return cl, nil
}

// Start registers the circuit (client) protocol stream handlers
func (c *Client) Start() {
c.host.SetStreamHandler(proto.ProtoIDv1, c.handleStreamV1)
c.host.SetStreamHandler(proto.ProtoIDv2Stop, c.handleStreamV2)
}

func (c *Client) Close() error {
c.ctxCancel()
c.host.RemoveStreamHandler(proto.ProtoIDv1)
c.host.RemoveStreamHandler(proto.ProtoIDv2Stop)
return nil
}
6 changes: 3 additions & 3 deletions p2p/protocol/circuitv2/client/listen.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"errors"
"net"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -32,7 +33,7 @@ func (l *Listener) Accept() (manet.Conn, error) {
return evt.conn, nil

case <-l.ctx.Done():
return nil, l.ctx.Err()
return nil, errors.New("circuit v2 client closed")
}
}
}
Expand All @@ -49,6 +50,5 @@ func (l *Listener) Multiaddr() ma.Multiaddr {
}

func (l *Listener) Close() error {
// noop for now
return nil
return (*Client)(l).Close()
}
6 changes: 4 additions & 2 deletions p2p/protocol/circuitv2/client/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"io"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
Expand All @@ -17,13 +18,13 @@ var circuitAddr = ma.Cast(circuitProtocol.VCode)

// AddTransport constructs a new p2p-circuit/v2 client and adds it as a transport to the
// host network
func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) error {
func AddTransport(h host.Host, upgrader *tptu.Upgrader) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}

c, err := New(ctx, h, upgrader)
c, err := New(h, upgrader)
if err != nil {
return fmt.Errorf("error constructing circuit client: %w", err)
}
Expand All @@ -45,6 +46,7 @@ func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) err

// Transport interface
var _ transport.Transport = (*Client)(nil)
var _ io.Closer = (*Client)(nil)

func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
conn, err := c.dial(ctx, a, p)
Expand Down
4 changes: 2 additions & 2 deletions p2p/protocol/circuitv2/test/compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestRelayCompatV2DialV1(t *testing.T) {

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransportV1(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestRelayCompatV1DialV2(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, hosts[0], upgraders[0])
addTransportV1(t, ctx, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
Expand Down
23 changes: 11 additions & 12 deletions p2p/protocol/circuitv2/test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import (

logging "github.com/ipfs/go-log/v2"
bhost "github.com/libp2p/go-libp2p-blankhost"
metrics "github.com/libp2p/go-libp2p-core/metrics"
pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem"
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
tcp "github.com/libp2p/go-tcp-transport"
"github.com/libp2p/go-tcp-transport"
ma "github.com/multiformats/go-multiaddr"
)

Expand Down Expand Up @@ -85,9 +85,8 @@ func connect(t *testing.T, a, b host.Host) {
}
}

func addTransport(t *testing.T, ctx context.Context, h host.Host, upgrader *tptu.Upgrader) {
err := client.AddTransport(ctx, h, upgrader)
if err != nil {
func addTransport(t *testing.T, h host.Host, upgrader *tptu.Upgrader) {
if err := client.AddTransport(h, upgrader); err != nil {
t.Fatal(err)
}
}
Expand All @@ -97,8 +96,8 @@ func TestBasicRelay(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -184,8 +183,8 @@ func TestRelayLimitTime(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan error, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -258,8 +257,8 @@ func TestRelayLimitData(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan int, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down

0 comments on commit a5f982f

Please sign in to comment.