diff --git a/simple/provider.go b/simple/provider.go index b7da4b2..3993d45 100644 --- a/simple/provider.go +++ b/simple/provider.go @@ -5,17 +5,16 @@ package simple import ( "context" + "time" - cid "github.com/ipfs/go-cid" + "github.com/ipfs/go-cid" q "github.com/ipfs/go-ipfs-provider/queue" logging "github.com/ipfs/go-log" - routing "github.com/libp2p/go-libp2p-core/routing" + "github.com/libp2p/go-libp2p-core/routing" ) var logP = logging.Logger("provider.simple") -const provideOutgoingWorkerLimit = 8 - // Provider announces blocks to the network type Provider struct { ctx context.Context @@ -23,15 +22,44 @@ type Provider struct { queue *q.Queue // used to announce providing to the network contentRouting routing.ContentRouting + // how long to wait for announce to complete before giving up + timeout time.Duration + // how many workers concurrently work through thhe queue + workerLimit int +} + +// Option defines the functional option type that can be used to configure +// provider instances +type Option func(*Provider) + +// WithTimeout is an option to set a timeout on a provider +func WithTimeout(timeout time.Duration) Option { + return func(p *Provider) { + p.timeout = timeout + } +} + +// MaxWorkers is an option to set the max workers on a provider +func MaxWorkers(count int) Option { + return func(p *Provider) { + p.workerLimit = count + } } // NewProvider creates a provider that announces blocks to the network using a content router -func NewProvider(ctx context.Context, queue *q.Queue, contentRouting routing.ContentRouting) *Provider { - return &Provider{ +func NewProvider(ctx context.Context, queue *q.Queue, contentRouting routing.ContentRouting, options ...Option) *Provider { + p := &Provider{ ctx: ctx, queue: queue, contentRouting: contentRouting, + workerLimit: 8, } + + for _, option := range options { + option(p) + } + + return p } // Close stops the provider @@ -53,20 +81,33 @@ func (p *Provider) Provide(root cid.Cid) error { // Handle all outgoing cids by providing (announcing) them func (p *Provider) handleAnnouncements() { - for workers := 0; workers < provideOutgoingWorkerLimit; workers++ { + for workers := 0; workers < p.workerLimit; workers++ { go func() { for p.ctx.Err() == nil { select { case <-p.ctx.Done(): return case c := <-p.queue.Dequeue(): - logP.Info("announce - start - ", c) - if err := p.contentRouting.Provide(p.ctx, c, true); err != nil { - logP.Warningf("Unable to provide entry: %s, %s", c, err) - } - logP.Info("announce - end - ", c) + p.doProvide(c) } } }() } } + +func (p *Provider) doProvide(c cid.Cid) { + ctx := p.ctx + if p.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, p.timeout) + defer cancel() + } else { + ctx = p.ctx + } + + logP.Info("announce - start - ", c) + if err := p.contentRouting.Provide(ctx, c, true); err != nil { + logP.Warningf("Unable to provide entry: %s, %s", c, err) + } + logP.Info("announce - end - ", c) +} diff --git a/simple/provider_test.go b/simple/provider_test.go index 6fbc528..deb0032 100644 --- a/simple/provider_test.go +++ b/simple/provider_test.go @@ -24,7 +24,11 @@ type mockRouting struct { } func (r *mockRouting) Provide(ctx context.Context, cid cid.Cid, recursive bool) error { - r.provided <- cid + select { + case r.provided <- cid: + case <-ctx.Done(): + panic("context cancelled, but shouldn't have") + } return nil } @@ -81,3 +85,47 @@ func TestAnnouncement(t *testing.T) { } } } + +func TestAnnouncementTimeout(t *testing.T) { + ctx := context.Background() + defer ctx.Done() + + ds := sync.MutexWrap(datastore.NewMapDatastore()) + queue, err := q.NewQueue(ctx, "test", ds) + if err != nil { + t.Fatal(err) + } + + r := mockContentRouting() + + prov := NewProvider(ctx, queue, r, WithTimeout(1*time.Second)) + prov.Run() + + cids := cid.NewSet() + + for i := 0; i < 100; i++ { + c := blockGenerator.Next().Cid() + cids.Add(c) + } + + go func() { + for _, c := range cids.Keys() { + err = prov.Provide(c) + // A little goroutine stirring to exercise some different states + r := rand.Intn(10) + time.Sleep(time.Microsecond * time.Duration(r)) + } + }() + + for cids.Len() > 0 { + select { + case cp := <-r.provided: + if !cids.Has(cp) { + t.Fatal("Wrong CID provided") + } + cids.Remove(cp) + case <-time.After(time.Second * 5): + t.Fatal("Timeout waiting for cids to be provided.") + } + } +}