Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swarm: wait for transient connections to upgrade for NewStream #2542

Merged
merged 3 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 94 additions & 17 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport"
"golang.org/x/exp/slices"

logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -172,6 +173,11 @@ type Swarm struct {
m map[network.Notifiee]struct{}
}

directConnNotifs struct {
sync.Mutex
m map[peer.ID][]chan struct{}
}

transports struct {
sync.RWMutex
m map[int]transport.Transport
Expand Down Expand Up @@ -231,6 +237,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
s.listeners.m = make(map[transport.Listener]struct{})
s.transports.m = make(map[int]transport.Transport)
s.notifs.m = make(map[network.Notifiee]struct{})
s.directConnNotifs.m = make(map[peer.ID][]chan struct{})

for _, opt := range opts {
if err := opt(s); err != nil {
Expand Down Expand Up @@ -390,6 +397,19 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
c.notifyLk.Lock()
s.conns.Unlock()

// Notify goroutines waiting for a direct connection
if !c.Stat().Transient {
// Go routines interested in waiting for direct connection first acquire this lock
// and then acquire s.conns.RLock. Do not acquire this lock before conns.Unlock to
// prevent deadlock.
s.directConnNotifs.Lock()
for _, ch := range s.directConnNotifs.m[p] {
close(ch)
}
delete(s.directConnNotifs.m, p)
s.directConnNotifs.Unlock()
}

// Emit event after releasing `s.conns` lock so that a consumer can still
// use swarm methods that need the `s.conns` lock.
if isFirstConnection {
Expand Down Expand Up @@ -436,46 +456,103 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error

// Algorithm:
// 1. Find the best connection, otherwise, dial.
// 2. Try opening a stream.
// 3. If the underlying connection is, in fact, closed, close the outer
// 2. If the best connection is transient, wait for a direct conn via conn
// reversal or hole punching.
// 3. Try opening a stream.
// 4. If the underlying connection is, in fact, closed, close the outer
// connection and try again. We do this in case we have a closed
// connection but don't notice it until we actually try to open a
// stream.
//
// Note: We only dial once.
//
// TODO: Try all connections even if we get an error opening a stream on
// a non-closed connection.
dials := 0
numDials := 0
for {
// will prefer direct connections over relayed connections for opening streams
c := s.bestAcceptableConnToPeer(ctx, p)

c := s.bestConnToPeer(p)
if c == nil {
if nodial, _ := network.GetNoDial(ctx); nodial {
if nodial, _ := network.GetNoDial(ctx); !nodial {
numDials++
if numDials > DialAttempts {
return nil, errors.New("max dial attempts exceeded")
sukunrt marked this conversation as resolved.
Show resolved Hide resolved
}
var err error
c, err = s.dialPeer(ctx, p)
if err != nil {
return nil, err
}
} else {
return nil, network.ErrNoConn
}
}

if dials >= DialAttempts {
return nil, errors.New("max dial attempts exceeded")
}
dials++

useTransient, _ := network.GetUseTransient(ctx)
if !useTransient && c.Stat().Transient {
var err error
c, err = s.dialPeer(ctx, p)
c, err = s.waitForDirectConn(ctx, p)
if err != nil {
return nil, err
}
}

s, err := c.NewStream(ctx)
str, err := c.NewStream(ctx)
if err != nil {
if c.conn.IsClosed() {
continue
}
return nil, err
}
return s, nil
return str, nil
}
}

// waitForDirectConn waits for a direct connection established through hole punching or connection reversal.
func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) {
s.directConnNotifs.Lock()
c := s.bestConnToPeer(p)
if c == nil {
s.directConnNotifs.Unlock()
return nil, network.ErrNoConn
} else if !c.Stat().Transient {
s.directConnNotifs.Unlock()
return c, nil
}
sukunrt marked this conversation as resolved.
Show resolved Hide resolved

// Wait for transient connection to upgrade to a direct connection either by
// connection reversal or hole punching.
ch := make(chan struct{})
s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch)
s.directConnNotifs.Unlock()

// apply the DialPeer timeout
ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx))
defer cancel()

// Wait for notification.
select {
case <-ctx.Done():
// Remove ourselves from the notification list
s.directConnNotifs.Lock()
sukunrt marked this conversation as resolved.
Show resolved Hide resolved
defer s.directConnNotifs.Unlock()

s.directConnNotifs.m[p] = slices.DeleteFunc(
s.directConnNotifs.m[p],
func(c chan struct{}) bool { return c == ch },
)
if len(s.directConnNotifs.m[p]) == 0 {
delete(s.directConnNotifs.m, p)
}
return nil, ctx.Err()
case <-ch:
// We do not need to remove ourselves from the list here as the notifier
// clears the map entry
c := s.bestConnToPeer(p)
if c == nil {
return nil, network.ErrNoConn
}
if c.Stat().Transient {
return nil, network.ErrTransientConn
}
return c, nil
}
}

Expand Down
88 changes: 86 additions & 2 deletions p2p/test/basichost/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"context"
"fmt"
"testing"
"time"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -62,11 +65,92 @@ func TestNoStreamOverTransientConnection(t *testing.T) {
err = h1.Connect(context.Background(), h2Info)
require.NoError(t, err)

ctx := network.WithNoDial(context.Background(), "test")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
_, err = h1.NewStream(ctx, h2.ID(), "/testprotocol")

require.ErrorIs(t, err, network.ErrTransientConn)
require.Error(t, err)

_, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol")
require.NoError(t, err)
}

func TestNewStreamTransientConnection(t *testing.T) {
h1, err := libp2p.New(
libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
libp2p.EnableRelay(),
)
require.NoError(t, err)

h2, err := libp2p.New(
libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
libp2p.EnableRelay(),
)
require.NoError(t, err)

relay1, err := libp2p.New()
require.NoError(t, err)

_, err = relay.New(relay1)
require.NoError(t, err)

relay1info := peer.AddrInfo{
ID: relay1.ID(),
Addrs: relay1.Addrs(),
}
err = h1.Connect(context.Background(), relay1info)
require.NoError(t, err)

err = h2.Connect(context.Background(), relay1info)
require.NoError(t, err)

h2.SetStreamHandler("/testprotocol", func(s network.Stream) {
fmt.Println("testprotocol")

// End the example
s.Close()
})

_, err = client.Reserve(context.Background(), h2, relay1info)
require.NoError(t, err)

relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())

h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

// NewStream should block transient connections till we have a direct connection
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
s, err := h1.NewStream(ctx, h2.ID(), "/testprotocol")
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, s)

// NewStream should return a stream if a direct connection is established
// while waiting
done := make(chan bool, 2)
go func() {
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL)
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
s, err = h1.NewStream(ctx, h2.ID(), "/testprotocol")
require.NoError(t, err)
require.NotNil(t, s)
defer s.Close()
require.Equal(t, s.Conn().Stat().Direction, network.DirInbound)
done <- true
}()
go func() {
// connect h2 to h1 simulating connection reversal
h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
ctx = network.WithForceDirectDial(ctx, "test")
err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()})
assert.NoError(t, err)
done <- true
}()
<-done
<-done
}
92 changes: 92 additions & 0 deletions p2p/test/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package swarm_test
import (
"context"
"testing"
"time"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -68,3 +70,93 @@ func TestDialPeerTransientConnection(t *testing.T) {
require.Error(t, err)
require.Nil(t, conn)
}

func TestNewStreamTransientConnection(t *testing.T) {
h1, err := libp2p.New(
libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
libp2p.EnableRelay(),
)
require.NoError(t, err)

h2, err := libp2p.New(
libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
libp2p.EnableRelay(),
)
require.NoError(t, err)

relay1, err := libp2p.New()
require.NoError(t, err)

_, err = relay.New(relay1)
require.NoError(t, err)

relay1info := peer.AddrInfo{
ID: relay1.ID(),
Addrs: relay1.Addrs(),
}
err = h1.Connect(context.Background(), relay1info)
require.NoError(t, err)

err = h2.Connect(context.Background(), relay1info)
require.NoError(t, err)

_, err = client.Reserve(context.Background(), h2, relay1info)
require.NoError(t, err)

relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())

h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

// WithUseTransient should succeed
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
ctx = network.WithUseTransient(ctx, "test")
s, err := h1.Network().NewStream(ctx, h2.ID())
require.NoError(t, err)
require.NotNil(t, s)
defer s.Close()

// Without WithUseTransient should fail with context deadline exceeded
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
s, err = h1.Network().NewStream(ctx, h2.ID())
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, s)

// Provide h2's direct address to h1.
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL)
// network.NoDial should also fail
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
s, err = h1.Network().NewStream(ctx, h2.ID())
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, s)

done := make(chan bool, 2)
// NewStream should return a stream if an incoming direct connection is established
go func() {
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
s, err = h1.Network().NewStream(ctx, h2.ID())
assert.NoError(t, err)
assert.NotNil(t, s)
defer s.Close()
require.Equal(t, s.Conn().Stat().Direction, network.DirInbound)
done <- true
}()
go func() {
// connect h2 to h1 simulating connection reversal
h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
ctx = network.WithForceDirectDial(ctx, "test")
err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()})
assert.NoError(t, err)
done <- true
}()

<-done
<-done
}