From b2a323ea4cf42222bc689b50488290d3030ee97f Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 26 Apr 2018 14:55:01 -0700 Subject: [PATCH] replace GetValues with a GetValue that takes options Changes for: https://github.com/libp2p/go-libp2p-routing/pull/21 (also fixes some test issues) --- ext_test.go | 6 +- options.go | 29 +++++++++ records.go | 9 +-- routing.go | 166 ++++++++++++++++++++++++++++++++-------------------- 4 files changed, 137 insertions(+), 73 deletions(-) create mode 100644 options.go diff --git a/ext_test.go b/ext_test.go index 2a5340b32..ce8fd45e7 100644 --- a/ext_test.go +++ b/ext_test.go @@ -2,7 +2,6 @@ package dht import ( "context" - "io" "math/rand" "testing" "time" @@ -35,9 +34,8 @@ func TestGetFailures(t *testing.T) { d := NewDHT(ctx, hosts[0], tsds) d.Update(ctx, hosts[1].ID()) - // Reply with failures to every message hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { - s.Close() + // hang forever }) // This one should time out @@ -48,7 +46,7 @@ func TestGetFailures(t *testing.T) { err = merr[0] } - if err != io.EOF { + if err != context.DeadlineExceeded { t.Fatal("Got different error than we expected", err) } } else { diff --git a/options.go b/options.go new file mode 100644 index 000000000..8f4978b5a --- /dev/null +++ b/options.go @@ -0,0 +1,29 @@ +package dht + +import ( + ropts "github.com/libp2p/go-libp2p-routing/options" +) + +type quorumOptionKey struct{} + +// Quorum is a DHT option that tells the DHT how many peers it needs to get +// values from before returning the best one. +// +// Default: 16 +func Quorum(n int) ropts.Option { + return func(opts *ropts.Options) error { + if opts.Other == nil { + opts.Other = make(map[interface{}]interface{}, 1) + } + opts.Other[quorumOptionKey{}] = n + return nil + } +} + +func getQuorum(opts *ropts.Options) int64 { + responsesNeeded, ok := opts.Other[quorumOptionKey{}].(int) + if !ok { + responsesNeeded = 16 + } + return int64(responsesNeeded) +} diff --git a/records.go b/records.go index 20ab78c7d..29dae735d 100644 --- a/records.go +++ b/records.go @@ -78,17 +78,12 @@ func (dht *IpfsDHT) getPublicKeyFromDHT(ctx context.Context, p peer.ID) (ci.PubK // Only retrieve one value, because the public key is immutable // so there's no need to retrieve multiple versions pkkey := routing.KeyForPublicKey(p) - vals, err := dht.GetValues(ctx, pkkey, 1) + val, err := dht.GetValue(ctx, pkkey, Quorum(1)) if err != nil { return nil, err } - if len(vals) == 0 || vals[0].Val == nil { - log.Debugf("Could not find public key for %v in DHT", p) - return nil, routing.ErrNotFound - } - - pubk, err := ci.UnmarshalPublicKey(vals[0].Val) + pubk, err := ci.UnmarshalPublicKey(val) if err != nil { log.Errorf("Could not unmarshall public key retrieved from DHT for %v", p) return nil, err diff --git a/routing.go b/routing.go index 722009454..490e5532a 100644 --- a/routing.go +++ b/routing.go @@ -6,6 +6,7 @@ import ( "fmt" "runtime" "sync" + "sync/atomic" "time" proto "github.com/gogo/protobuf/proto" @@ -21,6 +22,7 @@ import ( record "github.com/libp2p/go-libp2p-record" routing "github.com/libp2p/go-libp2p-routing" notif "github.com/libp2p/go-libp2p-routing/notifications" + ropts "github.com/libp2p/go-libp2p-routing/options" ) // asyncQueryBuffer is the size of buffered channels in async queries. This @@ -35,7 +37,7 @@ var asyncQueryBuffer = 10 // PutValue adds value corresponding to given Key. // This is the top level "Store" operation of the DHT -func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte) (err error) { +func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, options ...ropts.Option) (err error) { eip := log.EventBegin(ctx, "PutValue") defer func() { eip.Append(loggableKey(key)) @@ -46,6 +48,8 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte) (err }() log.Debugf("PutValue %s", key) + // TODO: How to handle the offline option? + rec := record.MakePutRecord(key, value) rec.TimeReceived = proto.String(u.FormatRFC3339(time.Now())) err = dht.putLocal(key, rec) @@ -81,7 +85,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte) (err } // GetValue searches for the value corresponding to given Key. -func (dht *IpfsDHT) GetValue(ctx context.Context, key string) (_ []byte, err error) { +func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...ropts.Option) (_ []byte, err error) { eip := log.EventBegin(ctx, "GetValue") defer func() { eip.Append(loggableKey(key)) @@ -93,39 +97,55 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string) (_ []byte, err err ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - vals, err := dht.GetValues(ctx, key, 16) - if err != nil { + var cfg ropts.Options + if err := cfg.Apply(opts...); err != nil { return nil, err } - recs := make([][]byte, 0, len(vals)) - for _, v := range vals { - if v.Val != nil { - recs = append(recs, v.Val) - } - } - if len(recs) == 0 { - return nil, routing.ErrNotFound - } - - i, err := dht.Selector.BestRecord(key, recs) + results, err := dht.getValues(ctx, key, &cfg) if err != nil { return nil, err } - best := recs[i] - log.Debugf("GetValue %v %v", key, best) - if best == nil { - log.Errorf("GetValue yielded correct record with nil value.") - return nil, routing.ErrNotFound + var outdatedPeers, currentPeers []peer.ID + + var best []byte + for result := range results { + switch { + case result.val == nil: + outdatedPeers = append(outdatedPeers, result.from) + case best == nil: + best = result.val + fallthrough + case bytes.Equal(result.val, best): + currentPeers = append(currentPeers, result.from) + default: + i, err := dht.Selector.BestRecord(key, [][]byte{best, result.val}) + if err != nil { + log.Error(err) + return nil, err + } + switch i { + case 0: + outdatedPeers = append(outdatedPeers, result.from) + case 1: + outdatedPeers = append(outdatedPeers, currentPeers...) + currentPeers = append(currentPeers[:0], result.from) + default: + err := fmt.Errorf("invalid bad selector for key: %s", loggableKey(key)) + log.Error(err) + return nil, err + } + } } - fixupRec := record.MakePutRecord(key, best) - for _, v := range vals { - // if someone sent us a different 'less-valid' record, lets correct them - if !bytes.Equal(v.Val, best) { - go func(v routing.RecvdVal) { - if v.From == dht.self { + // if someone sent us a different 'less-valid' record, lets correct them + if best != nil && len(outdatedPeers) > 0 { + fixupRec := record.MakePutRecord(key, best) + for _, p := range outdatedPeers { + // TODO: Use a worker. + go func(p peer.ID) { + if p == dht.self { err := dht.putLocal(key, fixupRec) if err != nil { log.Error("Error correcting local dht entry:", err) @@ -134,45 +154,60 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string) (_ []byte, err err } ctx, cancel := context.WithTimeout(dht.Context(), time.Second*30) defer cancel() - err := dht.putValueToPeer(ctx, v.From, key, fixupRec) + err := dht.putValueToPeer(ctx, p, key, fixupRec) if err != nil { log.Error("Error correcting DHT entry: ", err) } - }(v) + }(p) } } + if err := ctx.Err(); err != nil { + return best, err + } + if best == nil { + return nil, routing.ErrNotFound + } + return best, nil } -func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []routing.RecvdVal, err error) { - eip := log.EventBegin(ctx, "GetValues") +type recvdVal struct { + val []byte + from peer.ID +} + +func (dht *IpfsDHT) getValues(ctx context.Context, key string, opts *ropts.Options) (_ <-chan recvdVal, _err error) { + eip := log.EventBegin(ctx, "getValues") defer func() { eip.Append(loggableKey(key)) - if err != nil { - eip.SetError(err) + if _err != nil { + eip.SetError(_err) } eip.Done() }() - vals := make([]routing.RecvdVal, 0, nvals) - var valslock sync.Mutex - // If we have it local, don't bother doing an RPC! + if err := ctx.Err(); err != nil { + return nil, err + } + + vals := make(chan recvdVal, 1) + + responsesNeeded := getQuorum(opts) + + // If we have it locally, don't bother doing an RPC! lrec, err := dht.getLocal(key) if err == nil { - // TODO: this is tricky, we don't always want to trust our own value - // what if the authoritative source updated it? - log.Debug("have it locally") - vals = append(vals, routing.RecvdVal{ - Val: lrec.GetValue(), - From: dht.self, - }) - - if nvals <= 1 { - return vals, nil + vals <- recvdVal{ + val: lrec.GetValue(), + from: dht.self, } - } else if nvals == 0 { - return nil, err + responsesNeeded-- + } + + if opts.Offline || responsesNeeded <= 0 { + close(vals) + return vals, nil } // get closest peers in the routing table @@ -212,18 +247,17 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []r res := &dhtQueryResult{closerPeers: peers} if rec.GetValue() != nil || err == errInvalidRecord { - rv := routing.RecvdVal{ - Val: rec.GetValue(), - From: p, + select { + case vals <- recvdVal{ + val: rec.GetValue(), + from: p, + }: + case <-ctx.Done(): + return nil, ctx.Err() } - valslock.Lock() - vals = append(vals, rv) - - // If we have collected enough records, we're done - if len(vals) >= nvals { + if atomic.AddInt64(&responsesNeeded, -1) <= 0 { res.success = true } - valslock.Unlock() } notif.PublishQueryEvent(parent, ¬if.QueryEvent{ @@ -235,13 +269,21 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []r return res, nil }) - // run it! - _, err = query.Run(ctx, rtp) - if len(vals) == 0 { - if err != nil { - return nil, err + go func() { + defer close(vals) + // run it! + _, err := query.Run(ctx, rtp) + + if err == nil || ctx.Err() != nil { + return } - } + + // Not much we can do about the error. + // Any error that's *not* a context related error is likely a + // programmer error. There's not much that a user can do about + // it and no real reason to expose it. + log.Error(err) + }() return vals, nil }