diff --git a/swarm.go b/swarm.go index 43f7a36f..a55c8d9c 100644 --- a/swarm.go +++ b/swarm.go @@ -41,6 +41,22 @@ var ErrAddrFiltered = errors.New("address filtered") // ErrDialTimeout is returned when one a dial times out due to the global timeout var ErrDialTimeout = errors.New("dial timed out") +type Option func(*Swarm) + +// WithConnectionGater sets a connection gater +func WithConnectionGater(gater connmgr.ConnectionGater) Option { + return func(s *Swarm) { + s.gater = gater + } +} + +// WithMetrics sets a metrics reporter +func WithMetrics(reporter metrics.Reporter) Option { + return func(s *Swarm) { + s.bwc = reporter + } +} + // Swarm is a connection muxer, allowing connections to other peers to // be opened and closed, while still using the same Chan for all // communication. The Chan sends/receives Messages, which note the @@ -98,17 +114,11 @@ type Swarm struct { } // NewSwarm constructs a Swarm. -// -// NOTE: go-libp2p will be moving to dependency injection soon. The variadic -// `extra` interface{} parameter facilitates the future migration. Supported -// elements are: -// - connmgr.ConnectionGater -func NewSwarm(local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { +func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) *Swarm { ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ local: local, peers: peers, - bwc: bwc, ctx: ctx, ctxCancel: cancel, } @@ -118,11 +128,8 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, ex s.transports.m = make(map[int]transport.Transport) s.notifs.m = make(map[network.Notifiee]struct{}) - for _, i := range extra { - switch v := i.(type) { - case connmgr.ConnectionGater: - s.gater = v - } + for _, opt := range opts { + opt(s) } s.dsync = newDialSync(s.dialWorkerLoop) diff --git a/swarm_test.go b/swarm_test.go index 9f799f66..07551fd0 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -16,7 +16,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - . "github.com/libp2p/go-libp2p-swarm" + swarm "github.com/libp2p/go-libp2p-swarm" . "github.com/libp2p/go-libp2p-swarm/testing" logging "github.com/ipfs/go-log" @@ -58,14 +58,14 @@ func EchoStreamHandler(stream network.Stream) { }() } -func makeDialOnlySwarm(t *testing.T) *Swarm { +func makeDialOnlySwarm(t *testing.T) *swarm.Swarm { swarm := GenSwarm(t, OptDialOnly) swarm.SetStreamHandler(EchoStreamHandler) return swarm } -func makeSwarms(t *testing.T, num int, opts ...Option) []*Swarm { - swarms := make([]*Swarm, 0, num) +func makeSwarms(t *testing.T, num int, opts ...Option) []*swarm.Swarm { + swarms := make([]*swarm.Swarm, 0, num) for i := 0; i < num; i++ { swarm := GenSwarm(t, opts...) swarm.SetStreamHandler(EchoStreamHandler) @@ -74,9 +74,9 @@ func makeSwarms(t *testing.T, num int, opts ...Option) []*Swarm { return swarms } -func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { +func connectSwarms(t *testing.T, ctx context.Context, swarms []*swarm.Swarm) { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *swarm.Swarm, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) if _, err := s.DialPeer(ctx, dst); err != nil { @@ -455,7 +455,7 @@ func TestPreventDialListenAddr(t *testing.T) { remote := peer.ID("foobar") s.Peerstore().AddAddr(remote, addr, time.Hour) _, err = s.DialPeer(context.Background(), remote) - if !errors.Is(err, ErrNoGoodAddresses) { + if !errors.Is(err, swarm.ErrNoGoodAddresses) { t.Fatal("expected dial to fail: %w", err) } } diff --git a/testing/testing.go b/testing/testing.go index 201b4f0f..b0d4c218 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -113,7 +113,11 @@ func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { ps.AddPrivKey(p.ID, p.PrivKey) t.Cleanup(func() { ps.Close() }) - s := swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) + swarmOpts := []swarm.Option{swarm.WithMetrics(metrics.NewBandwidthCounter())} + if cfg.connectionGater != nil { + swarmOpts = append(swarmOpts, swarm.WithConnectionGater(cfg.connectionGater)) + } + s := swarm.NewSwarm(p.ID, ps, swarmOpts...) upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater