diff --git a/bitswap.go b/bitswap.go index af648972..4a15fc58 100644 --- a/bitswap.go +++ b/bitswap.go @@ -148,6 +148,16 @@ func SetSimulateDontHavesOnTimeout(send bool) Option { } } +type TaskInfo = decision.TaskInfo +type TaskComparator = decision.TaskComparator + +// WithTaskComparator configures custom task prioritization logic. +func WithTaskComparator(comparator TaskComparator) Option { + return func(bs *Bitswap) { + bs.taskComparator = comparator + } +} + // New initializes a BitSwap instance that communicates over the provided // BitSwapNetwork. This function registers the returned instance as the network // delegate. Runs until context is cancelled or bitswap.Close is called. @@ -272,6 +282,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, activeEngineGauge, pendingBlocksGauge, activeBlocksGauge, + decision.WithTaskComparator(bs.taskComparator), ) bs.engine.SetSendDontHaves(bs.engineSetSendDontHaves) @@ -375,6 +386,8 @@ type Bitswap struct { // whether we should actually simulate dont haves on request timeout simulateDontHavesOnTimeout bool + + taskComparator TaskComparator } type counters struct { diff --git a/go.mod b/go.mod index 88e2eba5..548ca0de 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/ipfs/go-ipfs-util v0.0.2 github.com/ipfs/go-log v1.0.5 github.com/ipfs/go-metrics-interface v0.0.1 - github.com/ipfs/go-peertaskqueue v0.4.0 + github.com/ipfs/go-peertaskqueue v0.6.0 github.com/jbenet/goprocess v0.1.4 github.com/libp2p/go-buffer-pool v0.0.2 github.com/libp2p/go-libp2p v0.14.3 diff --git a/go.sum b/go.sum index 939c92b0..62b20d19 100644 --- a/go.sum +++ b/go.sum @@ -302,8 +302,8 @@ github.com/ipfs/go-log/v2 v2.1.3 h1:1iS3IU7aXRlbgUpN8yTTpJ53NXYjAe37vcI5+5nYrzk= github.com/ipfs/go-log/v2 v2.1.3/go.mod h1:/8d0SH3Su5Ooc31QlL1WysJhvyOTDCjcCZ9Axpmri6g= github.com/ipfs/go-metrics-interface v0.0.1 h1:j+cpbjYvu4R8zbleSs36gvB7jR+wsL2fGD6n0jO4kdg= github.com/ipfs/go-metrics-interface v0.0.1/go.mod h1:6s6euYU4zowdslK0GKHmqaIZ3j/b/tL7HTWtJ4VPgWY= -github.com/ipfs/go-peertaskqueue v0.4.0 h1:x1hFgA4JOUJ3ntPfqLRu6v4k6kKL0p07r3RSg9JNyHI= -github.com/ipfs/go-peertaskqueue v0.4.0/go.mod h1:KL9F49hXJMoXCad8e5anivjN+kWdr+CyGcyh4K6doLc= +github.com/ipfs/go-peertaskqueue v0.6.0 h1:BT1/PuNViVomiz1PnnP5+WmKsTNHrxIDvkZrkj4JhOg= +github.com/ipfs/go-peertaskqueue v0.6.0/go.mod h1:M/akTIE/z1jGNXMU7kFB4TeSEFvj68ow0Rrb04donIU= github.com/jackpal/gateway v1.0.5/go.mod h1:lTpwd4ACLXmpyiCTRtfiNyVnUmqT9RivzCDQetPfnjA= github.com/jackpal/go-nat-pmp v1.0.1/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= diff --git a/internal/decision/engine.go b/internal/decision/engine.go index df49f0bc..4426d8ce 100644 --- a/internal/decision/engine.go +++ b/internal/decision/engine.go @@ -19,6 +19,7 @@ import ( "github.com/ipfs/go-metrics-interface" "github.com/ipfs/go-peertaskqueue" "github.com/ipfs/go-peertaskqueue/peertask" + "github.com/ipfs/go-peertaskqueue/peertracker" process "github.com/jbenet/goprocess" "github.com/libp2p/go-libp2p-core/peer" ) @@ -175,6 +176,60 @@ type Engine struct { // used to ensure metrics are reported each fixed number of operation metricsLock sync.Mutex metricUpdateCounter int + + taskComparator TaskComparator +} + +// TaskInfo represents the details of a request from a peer. +type TaskInfo struct { + Peer peer.ID + // The CID of the block + Cid cid.Cid + // Tasks can be want-have or want-block + IsWantBlock bool + // Whether to immediately send a response if the block is not found + SendDontHave bool + // The size of the block corresponding to the task + BlockSize int + // Whether the block was found + HaveBlock bool +} + +// TaskComparator is used for task prioritization. +// It should return true if task 'ta' has higher priority than task 'tb' +type TaskComparator func(ta, tb *TaskInfo) bool + +type Option func(*Engine) + +func WithTaskComparator(comparator TaskComparator) Option { + return func(e *Engine) { + e.taskComparator = comparator + } +} + +// wrapTaskComparator wraps a TaskComparator so it can be used as a QueueTaskComparator +func wrapTaskComparator(tc TaskComparator) peertask.QueueTaskComparator { + return func(a, b *peertask.QueueTask) bool { + taskDataA := a.Task.Data.(*taskData) + taskInfoA := &TaskInfo{ + Peer: a.Target, + Cid: a.Task.Topic.(cid.Cid), + IsWantBlock: taskDataA.IsWantBlock, + SendDontHave: taskDataA.SendDontHave, + BlockSize: taskDataA.BlockSize, + HaveBlock: taskDataA.HaveBlock, + } + taskDataB := b.Task.Data.(*taskData) + taskInfoB := &TaskInfo{ + Peer: b.Target, + Cid: b.Task.Topic.(cid.Cid), + IsWantBlock: taskDataB.IsWantBlock, + SendDontHave: taskDataB.SendDontHave, + BlockSize: taskDataB.BlockSize, + HaveBlock: taskDataB.HaveBlock, + } + return tc(taskInfoA, taskInfoB) + } } // NewEngine creates a new block sending engine for the given block store. @@ -192,6 +247,7 @@ func NewEngine( activeEngineGauge metrics.Gauge, pendingBlocksGauge metrics.Gauge, activeBlocksGauge metrics.Gauge, + opts ...Option, ) *Engine { return newEngine( ctx, @@ -207,6 +263,7 @@ func NewEngine( activeEngineGauge, pendingBlocksGauge, activeBlocksGauge, + opts..., ) } @@ -223,6 +280,7 @@ func newEngine( activeEngineGauge metrics.Gauge, pendingBlocksGauge metrics.Gauge, activeBlocksGauge metrics.Gauge, + opts ...Option, ) *Engine { if scoreLedger == nil { @@ -247,12 +305,28 @@ func newEngine( } e.tagQueued = fmt.Sprintf(tagFormat, "queued", uuid.New().String()) e.tagUseful = fmt.Sprintf(tagFormat, "useful", uuid.New().String()) - e.peerRequestQueue = peertaskqueue.New( + + for _, opt := range opts { + opt(e) + } + + // default peer task queue options + peerTaskQueueOpts := []peertaskqueue.Option{ peertaskqueue.OnPeerAddedHook(e.onPeerAdded), peertaskqueue.OnPeerRemovedHook(e.onPeerRemoved), peertaskqueue.TaskMerger(newTaskMerger()), peertaskqueue.IgnoreFreezing(true), - peertaskqueue.MaxOutstandingWorkPerPeer(maxOutstandingBytesPerPeer)) + peertaskqueue.MaxOutstandingWorkPerPeer(maxOutstandingBytesPerPeer), + } + + if e.taskComparator != nil { + queueTaskComparator := wrapTaskComparator(e.taskComparator) + peerTaskQueueOpts = append(peerTaskQueueOpts, peertaskqueue.PeerComparator(peertracker.TaskPriorityPeerComparator(queueTaskComparator))) + peerTaskQueueOpts = append(peerTaskQueueOpts, peertaskqueue.TaskComparator(queueTaskComparator)) + } + + e.peerRequestQueue = peertaskqueue.New(peerTaskQueueOpts...) + return e } diff --git a/internal/decision/engine_test.go b/internal/decision/engine_test.go index d8445fde..acde1795 100644 --- a/internal/decision/engine_test.go +++ b/internal/decision/engine_test.go @@ -18,6 +18,7 @@ import ( "github.com/ipfs/go-metrics-interface" blocks "github.com/ipfs/go-block-format" + "github.com/ipfs/go-cid" ds "github.com/ipfs/go-datastore" dssync "github.com/ipfs/go-datastore/sync" blockstore "github.com/ipfs/go-ipfs-blockstore" @@ -92,14 +93,14 @@ type engineSet struct { Blockstore blockstore.Blockstore } -func newTestEngine(ctx context.Context, idStr string) engineSet { - return newTestEngineWithSampling(ctx, idStr, shortTerm, nil, clock.New()) +func newTestEngine(ctx context.Context, idStr string, opts ...Option) engineSet { + return newTestEngineWithSampling(ctx, idStr, shortTerm, nil, clock.New(), opts...) } -func newTestEngineWithSampling(ctx context.Context, idStr string, peerSampleInterval time.Duration, sampleCh chan struct{}, clock clock.Clock) engineSet { +func newTestEngineWithSampling(ctx context.Context, idStr string, peerSampleInterval time.Duration, sampleCh chan struct{}, clock clock.Clock, opts ...Option) engineSet { fpt := &fakePeerTagger{} bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) - e := newEngineForTesting(ctx, bs, 4, defaults.BitswapEngineTaskWorkerCount, defaults.BitswapMaxOutstandingBytesPerPeer, fpt, "localhost", 0, NewTestScoreLedger(peerSampleInterval, sampleCh, clock)) + e := newEngineForTesting(ctx, bs, 4, defaults.BitswapEngineTaskWorkerCount, defaults.BitswapMaxOutstandingBytesPerPeer, fpt, "localhost", 0, NewTestScoreLedger(peerSampleInterval, sampleCh, clock), opts...) e.StartWorkers(ctx, process.WithTeardown(func() error { return nil })) return engineSet{ Peer: peer.ID(idStr), @@ -193,6 +194,7 @@ func newEngineForTesting( self peer.ID, maxReplaceSize int, scoreLedger ScoreLedger, + opts ...Option, ) *Engine { testPendingEngineGauge := metrics.NewCtx(ctx, "pending_tasks", "Total number of pending tasks").Gauge() testActiveEngineGauge := metrics.NewCtx(ctx, "active_tasks", "Total number of active tasks").Gauge() @@ -212,6 +214,7 @@ func newEngineForTesting( testActiveEngineGauge, testPendingBlocksGauge, testActiveBlocksGauge, + opts..., ) } @@ -1054,6 +1057,61 @@ func TestWantlistForPeer(t *testing.T) { } +func TestTaskComparator(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + keys := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + cids := make(map[cid.Cid]int) + blks := make([]blocks.Block, 0, len(keys)) + for i, letter := range keys { + block := blocks.NewBlock([]byte(letter)) + blks = append(blks, block) + cids[block.Cid()] = i + } + + fpt := &fakePeerTagger{} + sl := NewTestScoreLedger(shortTerm, nil, clock.New()) + bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) + if err := bs.PutMany(blks); err != nil { + t.Fatal(err) + } + + // use a single task worker so that the order of outgoing messages is deterministic + engineTaskWorkerCount := 1 + e := newEngineForTesting(ctx, bs, 4, engineTaskWorkerCount, defaults.BitswapMaxOutstandingBytesPerPeer, fpt, "localhost", 0, sl, + // if this Option is omitted, the test fails + WithTaskComparator(func(ta, tb *TaskInfo) bool { + // prioritize based on lexicographic ordering of block content + return cids[ta.Cid] < cids[tb.Cid] + }), + ) + e.StartWorkers(ctx, process.WithTeardown(func() error { return nil })) + + // rely on randomness of Go map's iteration order to add Want entries in random order + peerIDs := make([]peer.ID, len(keys)) + for _, i := range cids { + peerID := libp2ptest.RandPeerIDFatal(t) + peerIDs[i] = peerID + partnerWantBlocks(e, keys[i:i+1], peerID) + } + + // check that outgoing messages are sent in the correct order + for i, peerID := range peerIDs { + next := <-e.Outbox() + envelope := <-next + if peerID != envelope.Peer { + t.Errorf("expected message for peer ID %#v but instead got message for peer ID %#v", peerID, envelope.Peer) + } + responseBlocks := envelope.Message.Blocks() + if len(responseBlocks) != 1 { + t.Errorf("expected 1 block in response but instead got %v", len(blks)) + } else if responseBlocks[0].Cid() != blks[i].Cid() { + t.Errorf("expected block with CID %#v but instead got block with CID %#v", blks[i].Cid(), responseBlocks[0].Cid()) + } + } +} + func TestTaggingPeers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel()