diff --git a/dht.go b/dht.go index 91302cc08..d27458331 100644 --- a/dht.go +++ b/dht.go @@ -86,7 +86,7 @@ type IpfsDHT struct { // DHT protocols we can respond to. serverProtocols []protocol.ID - auto bool + auto ModeOpt mode mode modeLk sync.Mutex @@ -159,15 +159,11 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.Validator = cfg.validator + dht.auto = cfg.mode switch cfg.mode { - case ModeAuto: - dht.auto = true + case ModeAuto, ModeClient: dht.mode = modeClient - case ModeClient: - dht.auto = false - dht.mode = modeClient - case ModeServer: - dht.auto = false + case ModeAutoServer, ModeServer: dht.mode = modeServer default: return nil, fmt.Errorf("invalid dht mode %d", cfg.mode) @@ -312,6 +308,11 @@ func makeRoutingTable(dht *IpfsDHT, cfg config) (*kb.RoutingTable, error) { return rt, err } +// Mode allows introspection of the operation mode of the DHT +func (dht *IpfsDHT) Mode() ModeOpt { + return dht.auto +} + // fixLowPeers tries to get more peers into the routing table if we're below the threshold func (dht *IpfsDHT) fixLowPeersRoutine(proc goprocess.Process) { for { diff --git a/dht_options.go b/dht_options.go index 1e15031d6..3637c2136 100644 --- a/dht_options.go +++ b/dht_options.go @@ -25,6 +25,8 @@ const ( ModeClient // ModeServer operates the DHT as a server, it can both send and respond to queries ModeServer + // ModeAutoServer operates in the same way as ModeAuto, but acts as a server when reachability is unknown + ModeAutoServer ) // DefaultPrefix is the application specific prefix attached to all DHT protocols by default. @@ -256,6 +258,15 @@ func ProtocolPrefix(prefix protocol.ID) Option { } } +// ProtocolExtension adds an application specific protocol to the DHT protocol. For example, +// /ipfs/lan/kad/1.0.0 instead of /ipfs/kad/1.0.0. extension should be of the form /lan. +func ProtocolExtension(ext protocol.ID) Option { + return func(c *config) error { + c.protocolPrefix += ext + return nil + } +} + // BucketSize configures the bucket size (k in the Kademlia paper) of the routing table. // // The default value is 20. diff --git a/dht_test.go b/dht_test.go index 09f2b85e5..2cd45929c 100644 --- a/dht_test.go +++ b/dht_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" pb "github.com/libp2p/go-libp2p-kad-dht/pb" + test "github.com/libp2p/go-libp2p-kad-dht/internal/testing" "github.com/ipfs/go-cid" u "github.com/ipfs/go-ipfs-util" @@ -72,33 +73,8 @@ type blankValidator struct{} func (blankValidator) Validate(_ string, _ []byte) error { return nil } func (blankValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil } -type testValidator struct{} - -func (testValidator) Select(_ string, bs [][]byte) (int, error) { - index := -1 - for i, b := range bs { - if bytes.Equal(b, []byte("newer")) { - index = i - } else if bytes.Equal(b, []byte("valid")) { - if index == -1 { - index = i - } - } - } - if index == -1 { - return -1, errors.New("no rec found") - } - return index, nil -} -func (testValidator) Validate(_ string, b []byte) error { - if bytes.Equal(b, []byte("expired")) { - return errors.New("expired") - } - return nil -} - type testAtomicPutValidator struct { - testValidator + test.TestValidator } // selects the entry with the 'highest' last byte @@ -372,7 +348,7 @@ func TestValueSetInvalid(t *testing.T) { defer dhtA.host.Close() defer dhtB.host.Close() - dhtA.Validator.(record.NamespacedValidator)["v"] = testValidator{} + dhtA.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} dhtB.Validator.(record.NamespacedValidator)["v"] = blankValidator{} connect(t, ctx, dhtA, dhtB) @@ -451,8 +427,8 @@ func TestSearchValue(t *testing.T) { connect(t, ctx, dhtA, dhtB) - dhtA.Validator.(record.NamespacedValidator)["v"] = testValidator{} - dhtB.Validator.(record.NamespacedValidator)["v"] = testValidator{} + dhtA.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} + dhtB.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} ctxT, cancel := context.WithTimeout(ctx, time.Second) defer cancel() @@ -554,7 +530,7 @@ func TestValueGetInvalid(t *testing.T) { defer dhtB.host.Close() dhtA.Validator.(record.NamespacedValidator)["v"] = blankValidator{} - dhtB.Validator.(record.NamespacedValidator)["v"] = testValidator{} + dhtB.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} connect(t, ctx, dhtA, dhtB) diff --git a/dual/dual.go b/dual/dual.go new file mode 100644 index 000000000..f60e6a92a --- /dev/null +++ b/dual/dual.go @@ -0,0 +1,216 @@ +// Package dual provides an implementaiton of a split or "dual" dht, where two parallel instances +// are maintained for the global internet and the local LAN respectively. +package dual + +import ( + "context" + "sync" + + "github.com/ipfs/go-cid" + ci "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-core/routing" + dht "github.com/libp2p/go-libp2p-kad-dht" + helper "github.com/libp2p/go-libp2p-routing-helpers" + + "github.com/hashicorp/go-multierror" +) + +// DHT implements the routing interface to provide two concrete DHT implementationts for use +// in IPFS that are used to support both global network users and disjoint LAN usecases. +type DHT struct { + WAN *dht.IpfsDHT + LAN *dht.IpfsDHT +} + +// LanExtension is used to differentiate local protocol requests from those on the WAN DHT. +const LanExtension protocol.ID = "/lan" + +// Assert that IPFS assumptions about interfaces aren't broken. These aren't a +// guarantee, but we can use them to aid refactoring. +var ( + _ routing.ContentRouting = (*DHT)(nil) + _ routing.Routing = (*DHT)(nil) + _ routing.PeerRouting = (*DHT)(nil) + _ routing.PubKeyFetcher = (*DHT)(nil) + _ routing.ValueStore = (*DHT)(nil) +) + +// New creates a new DualDHT instance. Options provided are forwarded on to the two concrete +// IpfsDHT internal constructions, modulo additional options used by the Dual DHT to enforce +// the LAN-vs-WAN distinction. +// Note: query or routing table functional options provided as arguments to this function +// will be overriden by this constructor. +func New(ctx context.Context, h host.Host, options ...dht.Option) (*DHT, error) { + wanOpts := append(options, + dht.QueryFilter(dht.PublicQueryFilter), + dht.RoutingTableFilter(dht.PublicRoutingTableFilter), + ) + wan, err := dht.New(ctx, h, wanOpts...) + if err != nil { + return nil, err + } + + // Unless overridden by user supplied options, the LAN DHT should default + // to 'AutoServer' mode. + lanOpts := append(options, + dht.ProtocolExtension(LanExtension), + dht.QueryFilter(dht.PrivateQueryFilter), + dht.RoutingTableFilter(dht.PrivateRoutingTableFilter), + ) + if wan.Mode() != dht.ModeClient { + lanOpts = append(lanOpts, dht.Mode(dht.ModeServer)) + } + lan, err := dht.New(ctx, h, lanOpts...) + if err != nil { + return nil, err + } + + impl := DHT{wan, lan} + return &impl, nil +} + +// Close closes the DHT context. +func (dht *DHT) Close() error { + return multierror.Append(dht.WAN.Close(), dht.LAN.Close()).ErrorOrNil() +} + +func (dht *DHT) activeWAN() bool { + return dht.WAN.RoutingTable().Size() > 0 +} + +// Provide adds the given cid to the content routing system. +func (dht *DHT) Provide(ctx context.Context, key cid.Cid, announce bool) error { + if dht.activeWAN() { + return dht.WAN.Provide(ctx, key, announce) + } + return dht.LAN.Provide(ctx, key, announce) +} + +// FindProvidersAsync searches for peers who are able to provide a given key +func (dht *DHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo { + reqCtx, cancel := context.WithCancel(ctx) + outCh := make(chan peer.AddrInfo) + wanCh := dht.WAN.FindProvidersAsync(reqCtx, key, count) + lanCh := dht.LAN.FindProvidersAsync(reqCtx, key, count) + zeroCount := (count == 0) + go func() { + defer cancel() + defer close(outCh) + + found := make(map[peer.ID]struct{}, count) + var pi peer.AddrInfo + for (zeroCount || count > 0) && (wanCh != nil || lanCh != nil) { + var ok bool + select { + case pi, ok = <-wanCh: + if !ok { + wanCh = nil + continue + } + case pi, ok = <-lanCh: + if !ok { + lanCh = nil + continue + } + } + // already found + if _, ok = found[pi.ID]; ok { + continue + } + + select { + case outCh <- pi: + found[pi.ID] = struct{}{} + count-- + case <-ctx.Done(): + return + } + } + }() + return outCh +} + +// FindPeer searches for a peer with given ID +// Note: with signed peer records, we can change this to short circuit once either DHT returns. +func (dht *DHT) FindPeer(ctx context.Context, pid peer.ID) (peer.AddrInfo, error) { + var wg sync.WaitGroup + wg.Add(2) + var wanInfo, lanInfo peer.AddrInfo + var wanErr, lanErr error + go func() { + defer wg.Done() + wanInfo, wanErr = dht.WAN.FindPeer(ctx, pid) + }() + go func() { + defer wg.Done() + lanInfo, lanErr = dht.LAN.FindPeer(ctx, pid) + }() + + wg.Wait() + + return peer.AddrInfo{ + ID: pid, + Addrs: append(wanInfo.Addrs, lanInfo.Addrs...), + }, multierror.Append(wanErr, lanErr).ErrorOrNil() +} + +// Bootstrap allows callers to hint to the routing system to get into a +// Boostrapped state and remain there. +func (dht *DHT) Bootstrap(ctx context.Context) error { + erra := dht.WAN.Bootstrap(ctx) + errb := dht.LAN.Bootstrap(ctx) + return multierror.Append(erra, errb).ErrorOrNil() +} + +// PutValue adds value corresponding to given Key. +func (dht *DHT) PutValue(ctx context.Context, key string, val []byte, opts ...routing.Option) error { + if dht.activeWAN() { + return dht.WAN.PutValue(ctx, key, val, opts...) + } + return dht.LAN.PutValue(ctx, key, val, opts...) +} + +// GetValue searches for the value corresponding to given Key. +func (d *DHT) GetValue(ctx context.Context, key string, opts ...routing.Option) ([]byte, error) { + lanCtx, cancelLan := context.WithCancel(ctx) + defer cancelLan() + + var ( + lanVal []byte + lanErr error + lanWaiter sync.WaitGroup + ) + lanWaiter.Add(1) + go func() { + defer lanWaiter.Done() + lanVal, lanErr = d.LAN.GetValue(lanCtx, key, opts...) + }() + + wanVal, wanErr := d.WAN.GetValue(ctx, key, opts...) + if wanErr == nil { + cancelLan() + } + lanWaiter.Wait() + if wanErr != nil { + if lanErr != nil { + return nil, multierror.Append(wanErr, lanErr).ErrorOrNil() + } + return lanVal, nil + } + return wanVal, nil +} + +// SearchValue searches for better values from this value +func (dht *DHT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) { + p := helper.Parallel{Routers: []routing.Routing{dht.WAN, dht.LAN}, Validator: dht.WAN.Validator} + return p.SearchValue(ctx, key, opts...) +} + +// GetPublicKey returns the public key for the given peer. +func (dht *DHT) GetPublicKey(ctx context.Context, pid peer.ID) (ci.PubKey, error) { + p := helper.Parallel{Routers: []routing.Routing{dht.WAN, dht.LAN}, Validator: dht.WAN.Validator} + return p.GetPublicKey(ctx, pid) +} diff --git a/dual/dual_test.go b/dual/dual_test.go new file mode 100644 index 000000000..87349f3c8 --- /dev/null +++ b/dual/dual_test.go @@ -0,0 +1,369 @@ +package dual + +import ( + "context" + "testing" + "time" + + "github.com/ipfs/go-cid" + u "github.com/ipfs/go-ipfs-util" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + peerstore "github.com/libp2p/go-libp2p-core/peerstore" + dht "github.com/libp2p/go-libp2p-kad-dht" + test "github.com/libp2p/go-libp2p-kad-dht/internal/testing" + record "github.com/libp2p/go-libp2p-record" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" +) + +var wancid, lancid cid.Cid + +func init() { + wancid = cid.NewCidV1(cid.DagCBOR, u.Hash([]byte("wan cid -- value"))) + lancid = cid.NewCidV1(cid.DagCBOR, u.Hash([]byte("lan cid -- value"))) +} + +type blankValidator struct{} + +func (blankValidator) Validate(_ string, _ []byte) error { return nil } +func (blankValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil } + +type customRtHelper struct { + allow peer.ID +} + +func MkFilterForPeer() (func(d *dht.IpfsDHT, conns []network.Conn) bool, *customRtHelper) { + helper := customRtHelper{} + f := func(_ *dht.IpfsDHT, conns []network.Conn) bool { + for _, c := range conns { + if c.RemotePeer() == helper.allow { + return true + } + } + return false + } + return f, &helper +} + +func setupDHTWithFilters(ctx context.Context, t *testing.T, options ...dht.Option) (*DHT, []*customRtHelper) { + h := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + + wanFilter, wanRef := MkFilterForPeer() + wanOpts := []dht.Option{ + dht.NamespacedValidator("v", blankValidator{}), + dht.ProtocolPrefix("/test"), + dht.DisableAutoRefresh(), + dht.RoutingTableFilter(wanFilter), + } + wan, err := dht.New(ctx, h, wanOpts...) + if err != nil { + t.Fatal(err) + } + + lanFilter, lanRef := MkFilterForPeer() + lanOpts := []dht.Option{ + dht.NamespacedValidator("v", blankValidator{}), + dht.ProtocolPrefix("/test"), + dht.ProtocolExtension(LanExtension), + dht.DisableAutoRefresh(), + dht.RoutingTableFilter(lanFilter), + dht.Mode(dht.ModeServer), + } + lan, err := dht.New(ctx, h, lanOpts...) + if err != nil { + t.Fatal(err) + } + + impl := DHT{wan, lan} + return &impl, []*customRtHelper{wanRef, lanRef} +} + +func setupDHT(ctx context.Context, t *testing.T, options ...dht.Option) *DHT { + t.Helper() + baseOpts := []dht.Option{ + dht.NamespacedValidator("v", blankValidator{}), + dht.ProtocolPrefix("/test"), + dht.DisableAutoRefresh(), + } + + d, err := New( + ctx, + bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append(baseOpts, options...)..., + ) + if err != nil { + t.Fatal(err) + } + return d +} + +func connect(ctx context.Context, t *testing.T, a, b *dht.IpfsDHT) { + t.Helper() + bid := b.PeerID() + baddr := b.Host().Peerstore().Addrs(bid) + if len(baddr) == 0 { + t.Fatal("no addresses for connection.") + } + a.Host().Peerstore().AddAddrs(bid, baddr, peerstore.TempAddrTTL) + if err := a.Host().Connect(ctx, peer.AddrInfo{ID: bid}); err != nil { + t.Fatal(err) + } + wait(ctx, t, a, b) +} + +func wait(ctx context.Context, t *testing.T, a, b *dht.IpfsDHT) { + t.Helper() + for a.RoutingTable().Find(b.PeerID()) == "" { + //fmt.Fprintf(os.Stderr, "%v\n", a.RoutingTable().GetPeerInfos()) + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case <-time.After(time.Millisecond * 5): + } + } +} + +func setupTier(ctx context.Context, t *testing.T) (*DHT, *dht.IpfsDHT, *dht.IpfsDHT) { + t.Helper() + baseOpts := []dht.Option{ + dht.NamespacedValidator("v", blankValidator{}), + dht.ProtocolPrefix("/test"), + dht.DisableAutoRefresh(), + } + + d, hlprs := setupDHTWithFilters(ctx, t) + + wan, err := dht.New( + ctx, + bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append(baseOpts, dht.Mode(dht.ModeServer))..., + ) + if err != nil { + t.Fatal(err) + } + hlprs[0].allow = wan.PeerID() + connect(ctx, t, d.WAN, wan) + + lan, err := dht.New( + ctx, + bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append(baseOpts, dht.Mode(dht.ModeServer), dht.ProtocolExtension("/lan"))..., + ) + if err != nil { + t.Fatal(err) + } + hlprs[1].allow = lan.PeerID() + connect(ctx, t, d.LAN, lan) + + return d, wan, lan +} + +func TestDualModes(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d := setupDHT(ctx, t) + defer d.Close() + + if d.WAN.Mode() != dht.ModeAuto { + t.Fatal("wrong default mode for wan") + } else if d.LAN.Mode() != dht.ModeServer { + t.Fatal("wrong default mode for lan") + } + + d2 := setupDHT(ctx, t, dht.Mode(dht.ModeClient)) + defer d2.Close() + if d2.WAN.Mode() != dht.ModeClient || + d2.LAN.Mode() != dht.ModeClient { + t.Fatal("wrong client mode operation") + } +} + +func TestFindProviderAsync(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d, wan, lan := setupTier(ctx, t) + defer d.Close() + defer wan.Close() + defer lan.Close() + + time.Sleep(5 * time.Millisecond) + + if err := wan.Provide(ctx, wancid, false); err != nil { + t.Fatal(err) + } + + if err := lan.Provide(ctx, lancid, true); err != nil { + t.Fatal(err) + } + + wpc := d.FindProvidersAsync(ctx, wancid, 1) + select { + case p := <-wpc: + if p.ID != wan.PeerID() { + t.Fatal("wrong wan provider") + } + case <-ctx.Done(): + t.Fatal("find provider timeout.") + } + + lpc := d.FindProvidersAsync(ctx, lancid, 1) + select { + case p := <-lpc: + if p.ID != lan.PeerID() { + t.Fatal("wrong lan provider") + } + case <-ctx.Done(): + t.Fatal("find provider timeout.") + } +} + +func TestValueGetSet(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d, wan, lan := setupTier(ctx, t) + defer d.Close() + defer wan.Close() + defer lan.Close() + + time.Sleep(5 * time.Millisecond) + + err := d.PutValue(ctx, "/v/hello", []byte("valid")) + if err != nil { + t.Fatal(err) + } + val, err := wan.GetValue(ctx, "/v/hello") + if err != nil { + t.Fatal(err) + } + if string(val) != "valid" { + t.Fatal("failed to get expected string.") + } + + _, err = lan.GetValue(ctx, "/v/hello") + if err == nil { + t.Fatal(err) + } +} + +func TestSearchValue(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d, wan, lan := setupTier(ctx, t) + defer d.Close() + defer wan.Close() + defer lan.Close() + + d.WAN.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} + d.LAN.Validator.(record.NamespacedValidator)["v"] = test.TestValidator{} + + _ = wan.PutValue(ctx, "/v/hello", []byte("valid")) + + valCh, err := d.SearchValue(ctx, "/v/hello", dht.Quorum(0)) + if err != nil { + t.Fatal(err) + } + + select { + case v := <-valCh: + if string(v) != "valid" { + t.Errorf("expected 'valid', got '%s'", string(v)) + } + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + + select { + case _, ok := <-valCh: + if ok { + t.Errorf("chan should close") + } + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + + err = lan.PutValue(ctx, "/v/hello", []byte("newer")) + if err != nil { + t.Error(err) + } + + valCh, err = d.SearchValue(ctx, "/v/hello", dht.Quorum(0)) + if err != nil { + t.Fatal(err) + } + + var lastVal []byte + for c := range valCh { + lastVal = c + } + if string(lastVal) != "newer" { + t.Fatal("incorrect best search value") + } +} + +func TestGetPublicKey(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d, wan, lan := setupTier(ctx, t) + defer d.Close() + defer wan.Close() + defer lan.Close() + + time.Sleep(5 * time.Millisecond) + + pk, err := d.GetPublicKey(ctx, wan.PeerID()) + if err != nil { + t.Fatal(err) + } + id, err := peer.IDFromPublicKey(pk) + if err != nil { + t.Fatal(err) + } + if id != wan.PeerID() { + t.Fatal("incorrect PK") + } + + pk, err = d.GetPublicKey(ctx, lan.PeerID()) + if err != nil { + t.Fatal(err) + } + id, err = peer.IDFromPublicKey(pk) + if err != nil { + t.Fatal(err) + } + if id != lan.PeerID() { + t.Fatal("incorrect PK") + } +} + +func TestFindPeer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + d, wan, lan := setupTier(ctx, t) + defer d.Close() + defer wan.Close() + defer lan.Close() + + time.Sleep(5 * time.Millisecond) + + p, err := d.FindPeer(ctx, lan.PeerID()) + if err != nil { + t.Fatal(err) + } + if len(p.Addrs) == 0 { + t.Fatal("expeced find peer to find addresses.") + } + p, err = d.FindPeer(ctx, wan.PeerID()) + if err != nil { + t.Fatal(err) + } + if len(p.Addrs) == 0 { + t.Fatal("expeced find peer to find addresses.") + } +} diff --git a/go.mod b/go.mod index edd3a3f12..14f8c830c 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/libp2p/go-libp2p-kbucket v0.3.3 github.com/libp2p/go-libp2p-peerstore v0.2.2 github.com/libp2p/go-libp2p-record v0.1.2 + github.com/libp2p/go-libp2p-routing-helpers v0.2.0 github.com/libp2p/go-libp2p-swarm v0.2.3 github.com/libp2p/go-libp2p-testing v0.1.1 github.com/libp2p/go-msgio v0.0.4 diff --git a/go.sum b/go.sum index c9bd54f76..61bb8f5a0 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,7 @@ github.com/gxed/hashland/murmur3 v0.0.1 h1:SheiaIt0sda5K+8FLz952/1iWS9zrnKsEJaOJ github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -264,6 +265,8 @@ github.com/libp2p/go-libp2p-pnet v0.2.0 h1:J6htxttBipJujEjz1y0a5+eYoiPcFHhSYHH6n github.com/libp2p/go-libp2p-pnet v0.2.0/go.mod h1:Qqvq6JH/oMZGwqs3N1Fqhv8NVhrdYcO0BW4wssv21LA= github.com/libp2p/go-libp2p-record v0.1.2 h1:M50VKzWnmUrk/M5/Dz99qO9Xh4vs8ijsK+7HkJvRP+0= github.com/libp2p/go-libp2p-record v0.1.2/go.mod h1:pal0eNcT5nqZaTV7UGhqeGqxFgGdsU/9W//C8dqjQDk= +github.com/libp2p/go-libp2p-routing-helpers v0.2.0 h1:+QKTsx2Bg0q3oueQ9CopTwKN5NsnF+qEC+sbkSVXnsU= +github.com/libp2p/go-libp2p-routing-helpers v0.2.0/go.mod h1:Db+7LRSPImkV9fOKsNWVW5IXyy9XDse92lUtO3O+jlo= github.com/libp2p/go-libp2p-secio v0.1.0/go.mod h1:tMJo2w7h3+wN4pgU2LSYeiKPrfqBgkOsdiKK77hE7c8= github.com/libp2p/go-libp2p-secio v0.2.0 h1:ywzZBsWEEz2KNTn5RtzauEDq5RFEefPsttXYwAWqHng= github.com/libp2p/go-libp2p-secio v0.2.0/go.mod h1:2JdZepB8J5V9mBp79BmwsaPQhRPNN2NrnB2lKQcdy6g= diff --git a/internal/testing/helper.go b/internal/testing/helper.go new file mode 100644 index 000000000..52961f3f8 --- /dev/null +++ b/internal/testing/helper.go @@ -0,0 +1,31 @@ +package testing + +import ( + "bytes" + "errors" +) + +type TestValidator struct{} + +func (TestValidator) Select(_ string, bs [][]byte) (int, error) { + index := -1 + for i, b := range bs { + if bytes.Equal(b, []byte("newer")) { + index = i + } else if bytes.Equal(b, []byte("valid")) { + if index == -1 { + index = i + } + } + } + if index == -1 { + return -1, errors.New("no rec found") + } + return index, nil +} +func (TestValidator) Validate(_ string, b []byte) error { + if bytes.Equal(b, []byte("expired")) { + return errors.New("expired") + } + return nil +} diff --git a/subscriber_notifee.go b/subscriber_notifee.go index f29bfe98c..f1a6a3efd 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -41,7 +41,7 @@ func newSubscriberNotifiee(dht *IpfsDHT) (*subscriberNotifee, error) { // register for event bus local routability changes in order to trigger switching between client and server modes // only register for events if the DHT is operating in ModeAuto - if dht.auto { + if dht.auto == ModeAuto || dht.auto == ModeAutoServer { evts = append(evts, new(event.EvtLocalReachabilityChanged)) } @@ -96,7 +96,7 @@ func (nn *subscriberNotifee) subscribe(proc goprocess.Process) { case event.EvtPeerIdentificationCompleted: handlePeerIdentificationCompletedEvent(dht, evt) case event.EvtLocalReachabilityChanged: - if dht.auto { + if dht.auto == ModeAuto || dht.auto == ModeAutoServer { handleLocalReachabilityChangedEvent(dht, evt) } else { // something has gone really wrong if we get an event we did not subscribe to @@ -150,8 +150,14 @@ func handleLocalReachabilityChangedEvent(dht *IpfsDHT, e event.EvtLocalReachabil var target mode switch e.Reachability { - case network.ReachabilityPrivate, network.ReachabilityUnknown: + case network.ReachabilityPrivate: target = modeClient + case network.ReachabilityUnknown: + if dht.auto == ModeAutoServer { + target = modeServer + } else { + target = modeClient + } case network.ReachabilityPublic: target = modeServer }