diff --git a/dht.go b/dht.go index 19e7f585c..7a5d5fa83 100644 --- a/dht.go +++ b/dht.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -40,6 +41,13 @@ var logger = logging.Logger("dht") const BaseConnMgrScore = 5 +type DHTMode int + +const ( + ModeServer = DHTMode(1) + ModeClient = DHTMode(2) +) + // IpfsDHT is an implementation of Kademlia with S/Kademlia modifications. // It is used to implement the base Routing module. type IpfsDHT struct { @@ -69,6 +77,9 @@ type IpfsDHT struct { protocols []protocol.ID // DHT protocols + mode DHTMode + modeLk sync.Mutex + bucketSize int alpha int // The concurrency parameter per path d int // Number of Disjoint Paths to query @@ -117,6 +128,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er // register for network notifs. dht.host.Network().Notify((*netNotifiee)(dht)) + go dht.handleProtocolChanges(ctx) + dht.proc = goprocessctx.WithContextAndTeardown(ctx, func() error { // remove ourselves from network notifs. dht.host.Network().StopNotify((*netNotifiee)(dht)) @@ -125,10 +138,11 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er dht.proc.AddChild(dht.providers.Process()) dht.Validator = cfg.Validator + dht.mode = ModeClient if !cfg.Client { - for _, p := range cfg.Protocols { - h.SetStreamHandler(p, dht.handleNewStream) + if err := dht.moveToServerMode(); err != nil { + return nil, err } } dht.startRefreshing() @@ -435,6 +449,61 @@ func (dht *IpfsDHT) betterPeersToQuery(pmes *pb.Message, p peer.ID, count int) [ return filtered } +func (dht *IpfsDHT) SetMode(m DHTMode) error { + dht.modeLk.Lock() + defer dht.modeLk.Unlock() + + if m == dht.mode { + return nil + } + + switch m { + case ModeServer: + return dht.moveToServerMode() + case ModeClient: + return dht.moveToClientMode() + default: + return fmt.Errorf("unrecognized dht mode: %d", m) + } +} + +func (dht *IpfsDHT) moveToServerMode() error { + dht.mode = ModeServer + for _, p := range dht.protocols { + dht.host.SetStreamHandler(p, dht.handleNewStream) + } + return nil +} + +func (dht *IpfsDHT) moveToClientMode() error { + dht.mode = ModeClient + for _, p := range dht.protocols { + dht.host.RemoveStreamHandler(p) + } + + pset := make(map[protocol.ID]bool) + for _, p := range dht.protocols { + pset[p] = true + } + + for _, c := range dht.host.Network().Conns() { + for _, s := range c.GetStreams() { + if pset[s.Protocol()] { + if s.Stat().Direction == network.DirInbound { + s.Reset() + } + } + } + } + return nil +} + +func (dht *IpfsDHT) getMode() DHTMode { + dht.modeLk.Lock() + defer dht.modeLk.Unlock() + return dht.mode +} + // Context return dht's context func (dht *IpfsDHT) Context() context.Context { return dht.ctx @@ -507,3 +576,54 @@ func (dht *IpfsDHT) newContextWithLocalTags(ctx context.Context, extraTags ...ta ) // ignoring error as it is unrelated to the actual function of this code. return ctx } + +func (dht *IpfsDHT) handleProtocolChanges(ctx context.Context) { + // register for event bus protocol ID changes + sub, err := dht.host.EventBus().Subscribe(new(event.EvtPeerProtocolsUpdated)) + if err != nil { + panic(err) + } + defer sub.Close() + + pmap := make(map[protocol.ID]bool) + for _, p := range dht.protocols { + pmap[p] = true + } + + for { + select { + case ie, ok := <-sub.Out(): + e, ok := ie.(event.EvtPeerProtocolsUpdated) + if !ok { + logger.Errorf("got wrong type from subscription: %T", ie) + return + } + + if !ok { + return + } + var drop, add bool + for _, p := range e.Added { + if pmap[p] { + add = true + } + } + for _, p := range e.Removed { + if pmap[p] { + drop = true + } + } + + if add && drop { + // TODO: discuss how to handle this case + logger.Warning("peer adding and dropping dht protocols? odd") + } else if add { + dht.RoutingTable().Update(e.Peer) + } else if drop { + dht.RoutingTable().Remove(e.Peer) + } + case <-ctx.Done(): + return + } + } +} diff --git a/dht_net.go b/dht_net.go index 31775ae8f..3d0f87bb9 100644 --- a/dht_net.go +++ b/dht_net.go @@ -14,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/metrics" pb "github.com/libp2p/go-libp2p-kad-dht/pb" + msmux "github.com/multiformats/go-multistream" ggio "github.com/gogo/protobuf/io" @@ -80,6 +81,11 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { defer timer.Stop() for { + if dht.getMode() != ModeServer { + logger.Errorf("ignoring incoming dht message while not in server mode") + return false + } + var req pb.Message msgbytes, err := r.ReadMsg() if err != nil { @@ -166,6 +172,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message ms, err := dht.messageSenderForPeer(ctx, p) if err != nil { + if err == msmux.ErrNotSupported { + dht.RoutingTable().Remove(p) + } stats.Record(ctx, metrics.SentRequestErrors.M(1)) return nil, err } @@ -174,6 +183,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message rpmes, err := ms.SendRequest(ctx, pmes) if err != nil { + if err == msmux.ErrNotSupported { + dht.RoutingTable().Remove(p) + } stats.Record(ctx, metrics.SentRequestErrors.M(1)) return nil, err } @@ -200,11 +212,17 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message ms, err := dht.messageSenderForPeer(ctx, p) if err != nil { + if err == msmux.ErrNotSupported { + dht.RoutingTable().Remove(p) + } stats.Record(ctx, metrics.SentMessageErrors.M(1)) return err } if err := ms.SendMessage(ctx, pmes); err != nil { + if err == msmux.ErrNotSupported { + dht.RoutingTable().Remove(p) + } stats.Record(ctx, metrics.SentMessageErrors.M(1)) return err } diff --git a/dht_test.go b/dht_test.go index 78de64fdd..31d35ae4e 100644 --- a/dht_test.go +++ b/dht_test.go @@ -1680,3 +1680,22 @@ func TestClientModeAtInit(t *testing.T) { err := pinger.Ping(context.Background(), client.PeerID()) assert.True(t, xerrors.Is(err, multistream.ErrNotSupported)) } + +func TestModeChange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientOnly := setupDHT(ctx, t, true) + clientToServer := setupDHT(ctx, t, true) + clientOnly.Host().Peerstore().AddAddrs(clientToServer.PeerID(), clientToServer.Host().Addrs(), peerstore.AddressTTL) + err := clientOnly.Ping(ctx, clientToServer.PeerID()) + assert.True(t, xerrors.Is(err, multistream.ErrNotSupported)) + err = clientToServer.SetMode(ModeServer) + assert.Nil(t, err) + err = clientOnly.Ping(ctx, clientToServer.PeerID()) + assert.Nil(t, err) + err = clientToServer.SetMode(ModeClient) + assert.Nil(t, err) + err = clientOnly.Ping(ctx, clientToServer.PeerID()) + assert.NotNil(t, err) +}