Skip to content

Commit

Permalink
dht mode toggling (modulo dynamic switching) (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
whyrusleeping authored and aschmahmann committed Mar 5, 2020
1 parent 5d313b1 commit d570496
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 2 deletions.
123 changes: 121 additions & 2 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/libp2p/go-libp2p-core/event"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
Expand Down Expand Up @@ -40,6 +41,13 @@ var logger = logging.Logger("dht")

const BaseConnMgrScore = 5

type DHTMode int

const (
ModeServer = DHTMode(1)
ModeClient = DHTMode(2)
)

// IpfsDHT is an implementation of Kademlia with S/Kademlia modifications.
// It is used to implement the base Routing module.
type IpfsDHT struct {
Expand Down Expand Up @@ -69,6 +77,9 @@ type IpfsDHT struct {

protocols []protocol.ID // DHT protocols

mode DHTMode
modeLk sync.Mutex

bucketSize int
alpha int // The concurrency parameter per path
d int // Number of Disjoint Paths to query
Expand Down Expand Up @@ -116,13 +127,15 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er

// register for network notifs.
dht.proc.Go((*subscriberNotifee)(dht).subscribe)
go dht.handleProtocolChanges(ctx)
// handle providers
dht.proc.AddChild(dht.providers.Process())
dht.Validator = cfg.Validator
dht.mode = ModeClient

if !cfg.Client {
for _, p := range cfg.Protocols {
h.SetStreamHandler(p, dht.handleNewStream)
if err := dht.moveToServerMode(); err != nil {
return nil, err
}
}
dht.startRefreshing()
Expand Down Expand Up @@ -435,6 +448,61 @@ func (dht *IpfsDHT) betterPeersToQuery(pmes *pb.Message, p peer.ID, count int) [
return filtered
}

func (dht *IpfsDHT) SetMode(m DHTMode) error {
dht.modeLk.Lock()
defer dht.modeLk.Unlock()

if m == dht.mode {
return nil
}

switch m {
case ModeServer:
return dht.moveToServerMode()
case ModeClient:
return dht.moveToClientMode()
default:
return fmt.Errorf("unrecognized dht mode: %d", m)
}
}

func (dht *IpfsDHT) moveToServerMode() error {
dht.mode = ModeServer
for _, p := range dht.protocols {
dht.host.SetStreamHandler(p, dht.handleNewStream)
}
return nil
}

func (dht *IpfsDHT) moveToClientMode() error {
dht.mode = ModeClient
for _, p := range dht.protocols {
dht.host.RemoveStreamHandler(p)
}

pset := make(map[protocol.ID]bool)
for _, p := range dht.protocols {
pset[p] = true
}

for _, c := range dht.host.Network().Conns() {
for _, s := range c.GetStreams() {
if pset[s.Protocol()] {
if s.Stat().Direction == network.DirInbound {
s.Reset()
}
}
}
}
return nil
}

func (dht *IpfsDHT) getMode() DHTMode {
dht.modeLk.Lock()
defer dht.modeLk.Unlock()
return dht.mode
}

// Context return dht's context
func (dht *IpfsDHT) Context() context.Context {
return dht.ctx
Expand Down Expand Up @@ -507,3 +575,54 @@ func (dht *IpfsDHT) newContextWithLocalTags(ctx context.Context, extraTags ...ta
) // ignoring error as it is unrelated to the actual function of this code.
return ctx
}

func (dht *IpfsDHT) handleProtocolChanges(ctx context.Context) {
// register for event bus protocol ID changes
sub, err := dht.host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated))
if err != nil {
panic(err)
}
defer sub.Close()

pmap := make(map[protocol.ID]bool)
for _, p := range dht.protocols {
pmap[p] = true
}

for {
select {
case ie, ok := <-sub.Out():
e, ok := ie.(event.EvtPeerProtocolsUpdated)
if !ok {
logger.Errorf("got wrong type from subscription: %T", ie)
return
}

if !ok {
return
}
var drop, add bool
for _, p := range e.Added {
if pmap[p] {
add = true
}
}
for _, p := range e.Removed {
if pmap[p] {
drop = true
}
}

if add && drop {
// TODO: discuss how to handle this case
logger.Warning("peer adding and dropping dht protocols? odd")
} else if add {
dht.RoutingTable().Update(e.Peer)
} else if drop {
dht.RoutingTable().Remove(e.Peer)
}
case <-ctx.Done():
return
}
}
}
18 changes: 18 additions & 0 deletions dht_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/libp2p/go-libp2p-kad-dht/metrics"
pb "github.com/libp2p/go-libp2p-kad-dht/pb"
msmux "github.com/multiformats/go-multistream"

ggio "github.com/gogo/protobuf/io"

Expand Down Expand Up @@ -80,6 +81,11 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
defer timer.Stop()

for {
if dht.getMode() != ModeServer {
logger.Errorf("ignoring incoming dht message while not in server mode")
return false
}

var req pb.Message
msgbytes, err := r.ReadMsg()
msgLen := len(msgbytes)
Expand Down Expand Up @@ -167,6 +173,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message

ms, err := dht.messageSenderForPeer(ctx, p)
if err != nil {
if err == msmux.ErrNotSupported {
dht.RoutingTable().Remove(p)
}
stats.Record(ctx,
metrics.SentRequests.M(1),
metrics.SentRequestErrors.M(1),
Expand All @@ -178,6 +187,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message

rpmes, err := ms.SendRequest(ctx, pmes)
if err != nil {
if err == msmux.ErrNotSupported {
dht.RoutingTable().Remove(p)
}
stats.Record(ctx,
metrics.SentRequests.M(1),
metrics.SentRequestErrors.M(1),
Expand All @@ -201,6 +213,9 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message

ms, err := dht.messageSenderForPeer(ctx, p)
if err != nil {
if err == msmux.ErrNotSupported {
dht.RoutingTable().Remove(p)
}
stats.Record(ctx,
metrics.SentMessages.M(1),
metrics.SentMessageErrors.M(1),
Expand All @@ -209,6 +224,9 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
}

if err := ms.SendMessage(ctx, pmes); err != nil {
if err == msmux.ErrNotSupported {
dht.RoutingTable().Remove(p)
}
stats.Record(ctx,
metrics.SentMessages.M(1),
metrics.SentMessageErrors.M(1),
Expand Down
19 changes: 19 additions & 0 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1704,3 +1704,22 @@ func TestClientModeAtInit(t *testing.T) {
err := pinger.Ping(context.Background(), client.PeerID())
assert.True(t, xerrors.Is(err, multistream.ErrNotSupported))
}

func TestModeChange(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

clientOnly := setupDHT(ctx, t, true)
clientToServer := setupDHT(ctx, t, true)
clientOnly.Host().Peerstore().AddAddrs(clientToServer.PeerID(), clientToServer.Host().Addrs(), peerstore.AddressTTL)
err := clientOnly.Ping(ctx, clientToServer.PeerID())
assert.True(t, xerrors.Is(err, multistream.ErrNotSupported))
err = clientToServer.SetMode(ModeServer)
assert.Nil(t, err)
err = clientOnly.Ping(ctx, clientToServer.PeerID())
assert.Nil(t, err)
err = clientToServer.SetMode(ModeClient)
assert.Nil(t, err)
err = clientOnly.Ping(ctx, clientToServer.PeerID())
assert.NotNil(t, err)
}

0 comments on commit d570496

Please sign in to comment.