diff --git a/impl/graphsync.go b/impl/graphsync.go index 43cdb12b..b1430e0c 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -17,12 +17,14 @@ import ( "github.com/ipfs/go-graphsync/peermanager" "github.com/ipfs/go-graphsync/requestmanager" "github.com/ipfs/go-graphsync/requestmanager/asyncloader" + "github.com/ipfs/go-graphsync/requestmanager/executor" requestorhooks "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/responsemanager" responderhooks "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/persistenceoptions" "github.com/ipfs/go-graphsync/responsemanager/responseassembler" "github.com/ipfs/go-graphsync/selectorvalidator" + "github.com/ipfs/go-graphsync/taskqueue" ) var log = logging.Logger("graphsync") @@ -40,6 +42,8 @@ type GraphSync struct { requestManager *requestmanager.RequestManager responseManager *responsemanager.ResponseManager asyncLoader *asyncloader.AsyncLoader + requestQueue taskqueue.TaskQueue + requestExecutor *executor.Executor responseAssembler *responseassembler.ResponseAssembler peerTaskQueue *peertaskqueue.PeerTaskQueue peerManager *peermanager.PeerMessageManager @@ -63,12 +67,13 @@ type GraphSync struct { } type graphsyncConfigOptions struct { - totalMaxMemoryResponder uint64 - maxMemoryPerPeerResponder uint64 - totalMaxMemoryRequestor uint64 - maxMemoryPerPeerRequestor uint64 - maxInProgressRequests uint64 - registerDefaultValidator bool + totalMaxMemoryResponder uint64 + maxMemoryPerPeerResponder uint64 + totalMaxMemoryRequestor uint64 + maxMemoryPerPeerRequestor uint64 + maxInProgressIncomingRequests uint64 + maxInProgressOutgoingRequests uint64 + registerDefaultValidator bool } // Option defines the functional option type that can be used to configure @@ -115,11 +120,19 @@ func MaxMemoryPerPeerRequestor(maxMemoryPerPeer uint64) Option { } } -// MaxInProgressRequests changes the maximum number of +// MaxInProgressIncomingRequests changes the maximum number of // graphsync requests that are processed in parallel (default 6) -func MaxInProgressRequests(maxInProgressRequests uint64) Option { +func MaxInProgressIncomingRequests(maxInProgressIncomingRequests uint64) Option { return func(gs *graphsyncConfigOptions) { - gs.maxInProgressRequests = maxInProgressRequests + gs.maxInProgressIncomingRequests = maxInProgressIncomingRequests + } +} + +// MaxInProgressOutgoingRequests changes the maximum number of +// graphsync requests that are processed in parallel (default 6) +func MaxInProgressOutgoingRequests(maxInProgressOutgoingRequests uint64) Option { + return func(gs *graphsyncConfigOptions) { + gs.maxInProgressOutgoingRequests = maxInProgressOutgoingRequests } } @@ -130,12 +143,13 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, ctx, cancel := context.WithCancel(parent) gsConfig := &graphsyncConfigOptions{ - totalMaxMemoryResponder: defaultTotalMaxMemory, - maxMemoryPerPeerResponder: defaultMaxMemoryPerPeer, - totalMaxMemoryRequestor: defaultTotalMaxMemory, - maxMemoryPerPeerRequestor: defaultMaxMemoryPerPeer, - maxInProgressRequests: defaultMaxInProgressRequests, - registerDefaultValidator: true, + totalMaxMemoryResponder: defaultTotalMaxMemory, + maxMemoryPerPeerResponder: defaultMaxMemoryPerPeer, + totalMaxMemoryRequestor: defaultTotalMaxMemory, + maxMemoryPerPeerRequestor: defaultMaxMemoryPerPeer, + maxInProgressIncomingRequests: defaultMaxInProgressRequests, + maxInProgressOutgoingRequests: defaultMaxInProgressRequests, + registerDefaultValidator: 
true, } for _, option := range options { option(gsConfig) @@ -164,16 +178,20 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, requestAllocator := allocator.NewAllocator(gsConfig.totalMaxMemoryRequestor, gsConfig.maxMemoryPerPeerRequestor) asyncLoader := asyncloader.New(ctx, linkSystem, requestAllocator) - requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, incomingBlockHooks, networkErrorListeners) + requestQueue := taskqueue.NewTaskQueue(ctx) + requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue) + requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad) responseAssembler := responseassembler.New(ctx, peerManager) peerTaskQueue := peertaskqueue.New() - responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressRequests) + responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests) graphSync := &GraphSync{ network: network, linkSystem: linkSystem, requestManager: requestManager, responseManager: responseManager, asyncLoader: asyncLoader, + requestQueue: requestQueue, + requestExecutor: requestExecutor, responseAssembler: responseAssembler, peerTaskQueue: peerTaskQueue, peerManager: peerManager, @@ -198,6 +216,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, requestManager.SetDelegate(peerManager) requestManager.Startup() + requestQueue.Startup(gsConfig.maxInProgressOutgoingRequests, requestExecutor) responseManager.Startup() network.SetDelegate((*graphSyncReceiver)(graphSync)) return graphSync diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index c7177bf2..c975b4e5 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -416,7 +416,7 @@ func TestPauseResumeRequest(t *testing.T) { progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension) - blockChain.VerifyResponseRange(ctx, progressChan, 0, stopPoint-1) + blockChain.VerifyResponseRange(ctx, progressChan, 0, stopPoint) timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should pause request", progressChan) @@ -424,7 +424,7 @@ func TestPauseResumeRequest(t *testing.T) { err := requestor.UnpauseRequest(requestID, td.extensionUpdate) require.NoError(t, err) - blockChain.VerifyRemainder(ctx, progressChan, stopPoint-1) + blockChain.VerifyRemainder(ctx, progressChan, stopPoint) testutil.VerifyEmptyErrors(ctx, t, errChan) require.Len(t, td.blockStore1, blockChainLength, "did not store all blocks") } diff --git a/requestmanager/client.go b/requestmanager/client.go index 43181f36..f72dcf07 100644 --- a/requestmanager/client.go +++ b/requestmanager/client.go @@ -9,19 +9,25 @@ import ( "github.com/hannahhoward/go-pubsub" blocks "github.com/ipfs/go-block-format" + "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" + "github.com/ipfs/go-peertaskqueue/peertask" "github.com/ipld/go-ipld-prime" + 
"github.com/ipld/go-ipld-prime/traversal" "github.com/ipld/go-ipld-prime/traversal/selector" "github.com/libp2p/go-libp2p-core/peer" "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/ipldutil" "github.com/ipfs/go-graphsync/listeners" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/messagequeue" "github.com/ipfs/go-graphsync/metadata" "github.com/ipfs/go-graphsync/notifications" + "github.com/ipfs/go-graphsync/requestmanager/executor" "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/requestmanager/types" + "github.com/ipfs/go-graphsync/taskqueue" ) // The code in this file implements the public interface of the request manager. @@ -35,17 +41,30 @@ const ( defaultPriority = graphsync.Priority(0) ) +type state uint64 + +const ( + queued state = iota + running + paused +) + type inProgressRequestStatus struct { - ctx context.Context - startTime time.Time - cancelFn func() - p peer.ID - terminalError chan error - resumeMessages chan []graphsync.ExtensionData - pauseMessages chan struct{} - paused bool - lastResponse atomic.Value - onTerminated []chan<- error + ctx context.Context + startTime time.Time + cancelFn func() + p peer.ID + terminalError error + pauseMessages chan struct{} + state state + lastResponse atomic.Value + onTerminated []chan<- error + request gsmsg.GraphSyncRequest + doNotSendCids *cid.Set + nodeStyleChooser traversal.LinkTargetNodePrototypeChooser + inProgressChan chan graphsync.ResponseProgress + inProgressErr chan error + traverser ipldutil.Traverser } // PeerHandler is an interface that can send requests to peers @@ -81,8 +100,8 @@ type RequestManager struct { inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus requestHooks RequestHooks responseHooks ResponseHooks - blockHooks BlockHooks networkErrorListeners *listeners.NetworkErrorListeners + requestQueue taskqueue.TaskQueue } type requestManagerMessage interface { @@ -99,19 +118,14 @@ type ResponseHooks interface { ProcessResponseHooks(p peer.ID, response graphsync.ResponseData) hooks.UpdateResult } -// BlockHooks run for each block loaded -type BlockHooks interface { - ProcessBlockHooks(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData) hooks.UpdateResult -} - // New generates a new request manager from a context, network, and selectorQuerier func New(ctx context.Context, asyncLoader AsyncLoader, linkSystem ipld.LinkSystem, requestHooks RequestHooks, responseHooks ResponseHooks, - blockHooks BlockHooks, networkErrorListeners *listeners.NetworkErrorListeners, + requestQueue taskqueue.TaskQueue, ) *RequestManager { ctx, cancel := context.WithCancel(ctx) return &RequestManager{ @@ -125,8 +139,8 @@ func New(ctx context.Context, inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus), requestHooks: requestHooks, responseHooks: responseHooks, - blockHooks: blockHooks, networkErrorListeners: networkErrorListeners, + requestQueue: requestQueue, } } @@ -227,7 +241,7 @@ func (rm *RequestManager) cancelRequestAndClose(requestID graphsync.RequestID, cancelMessageChannel := rm.messages for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil { select { - case cancelMessageChannel <- &cancelRequestMessage{requestID, false, nil, nil}: + case cancelMessageChannel <- &cancelRequestMessage{requestID, nil, nil}: cancelMessageChannel = nil // clear out any remaining responses, in case and "incoming reponse" // messages get processed before our cancel message @@ -248,7 
+262,7 @@ func (rm *RequestManager) cancelRequestAndClose(requestID graphsync.RequestID, // CancelRequest cancels the given request ID and waits for the request to terminate func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error { terminated := make(chan error, 1) - rm.send(&cancelRequestMessage{requestID, false, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done()) + rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done()) select { case <-rm.ctx.Done(): return errors.New("context cancelled") @@ -289,24 +303,14 @@ func (rm *RequestManager) PauseRequest(requestID graphsync.RequestID) error { } } -// ProcessBlockHooks processes block hooks for the given response & block and cancels -// the request as needed -func (rm *RequestManager) ProcessBlockHooks(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData) error { - result := rm.blockHooks.ProcessBlockHooks(p, response, block) - if len(result.Extensions) > 0 { - updateRequest := gsmsg.UpdateRequest(response.RequestID(), result.Extensions...) - rm.SendRequest(p, updateRequest) - } - if result.Err != nil { - _, isPause := result.Err.(hooks.ErrPaused) - rm.send(&cancelRequestMessage{response.RequestID(), isPause, nil, nil}, nil) - } - return result.Err +// GetRequestTask gets data for the given task in the request queue +func (rm *RequestManager) GetRequestTask(p peer.ID, task *peertask.Task, requestExecutionChan chan executor.RequestTask) { + rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil) } -// TerminateRequest marks a request done -func (rm *RequestManager) TerminateRequest(requestID graphsync.RequestID) { - rm.send(&terminateRequestMessage{requestID}, nil) +// ReleaseRequestTask releases a task request the requestQueue +func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err error) { + rm.send(&releaseRequestTaskMessage{p, task, err}, nil) } // SendRequest sends a request to the message queue diff --git a/requestmanager/executor/executor.go b/requestmanager/executor/executor.go index d79291e0..3158004a 100644 --- a/requestmanager/executor/executor.go +++ b/requestmanager/executor/executor.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "github.com/ipfs/go-cid" + "github.com/ipfs/go-peertaskqueue/peertask" ipld "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/traversal" @@ -18,183 +19,180 @@ import ( gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/requestmanager/types" + logging "github.com/ipfs/go-log/v2" ) +var log = logging.Logger("gs_request_executor") + +// Manager is an interface the Executor uses to interact with the request manager +type Manager interface { + SendRequest(peer.ID, gsmsg.GraphSyncRequest) + GetRequestTask(peer.ID, *peertask.Task, chan RequestTask) + ReleaseRequestTask(peer.ID, *peertask.Task, error) +} + +// BlockHooks run for each block loaded +type BlockHooks interface { + ProcessBlockHooks(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData) hooks.UpdateResult +} + // AsyncLoadFn is a function which given a request id and an ipld.Link, returns // a channel which will eventually return data for the link or an err type AsyncLoadFn func(peer.ID, graphsync.RequestID, ipld.Link, ipld.LinkContext) <-chan types.AsyncLoadResult -// ExecutionEnv are request parameters that last between requests -type ExecutionEnv 
struct { - Ctx context.Context - SendRequest func(peer.ID, gsmsg.GraphSyncRequest) - RunBlockHooks func(p peer.ID, response graphsync.ResponseData, blk graphsync.BlockData) error - TerminateRequest func(graphsync.RequestID) - WaitForMessages func(ctx context.Context, resumeMessages chan graphsync.ExtensionData) ([]graphsync.ExtensionData, error) - Loader AsyncLoadFn - LinkSystem ipld.LinkSystem +// Executor handles actually executing graphsync requests and verifying them. +// It has control of requests when they are in the "running" state, while +// the manager is in charge when requests are queued or paused +type Executor struct { + manager Manager + blockHooks BlockHooks + loader AsyncLoadFn } -// RequestExecution are parameters for a single request execution -type RequestExecution struct { - Ctx context.Context - P peer.ID - TerminalError chan error - Request gsmsg.GraphSyncRequest - LastResponse *atomic.Value - DoNotSendCids *cid.Set - NodePrototypeChooser traversal.LinkTargetNodePrototypeChooser - ResumeMessages chan []graphsync.ExtensionData - PauseMessages chan struct{} +// NewExecutor returns a new executor +func NewExecutor( + manager Manager, + blockHooks BlockHooks, + loader AsyncLoadFn) *Executor { + return &Executor{ + manager: manager, + blockHooks: blockHooks, + loader: loader, + } } -// Start begins execution of a request in a go routine -func (ee ExecutionEnv) Start(re RequestExecution) (chan graphsync.ResponseProgress, chan error) { - executor := &requestExecutor{ - inProgressChan: make(chan graphsync.ResponseProgress), - inProgressErr: make(chan error), - ctx: re.Ctx, - p: re.P, - terminalError: re.TerminalError, - request: re.Request, - lastResponse: re.LastResponse, - doNotSendCids: re.DoNotSendCids, - nodeStyleChooser: re.NodePrototypeChooser, - resumeMessages: re.ResumeMessages, - pauseMessages: re.PauseMessages, - env: ee, +func (e *Executor) ExecuteTask(ctx context.Context, pid peer.ID, task *peertask.Task) bool { + requestTaskChan := make(chan RequestTask) + var requestTask RequestTask + e.manager.GetRequestTask(pid, task, requestTaskChan) + select { + case requestTask = <-requestTaskChan: + case <-ctx.Done(): + return true + } + if requestTask.Empty { + log.Info("Empty task on peer request stack") + return false + } + log.Debugw("beginning request execution", "id", requestTask.Request.ID(), "peer", pid.String(), "root_cid", requestTask.Request.Root().String()) + err := e.traverse(requestTask) + if err != nil && !isContextErr(err) { + e.manager.SendRequest(requestTask.P, gsmsg.CancelRequest(requestTask.Request.ID())) + if !isPausedErr(err) { + select { + case <-requestTask.Ctx.Done(): + case requestTask.InProgressErr <- err: + } + } } - executor.sendRequest(executor.request) - go executor.run() - return executor.inProgressChan, executor.inProgressErr + e.manager.ReleaseRequestTask(pid, task, err) + log.Debugw("finishing response execution", "id", requestTask.Request.ID(), "peer", pid.String(), "root_cid", requestTask.Request.Root().String()) + return false } -type requestExecutor struct { - inProgressChan chan graphsync.ResponseProgress - inProgressErr chan error - ctx context.Context - p peer.ID - terminalError chan error - request gsmsg.GraphSyncRequest - lastResponse *atomic.Value - nodeStyleChooser traversal.LinkTargetNodePrototypeChooser - resumeMessages chan []graphsync.ExtensionData - pauseMessages chan struct{} - doNotSendCids *cid.Set - env ExecutionEnv - restartNeeded bool - pendingExtensions []graphsync.ExtensionData +// RequestTask are parameters for a 
single request execution +type RequestTask struct { + Ctx context.Context + Request gsmsg.GraphSyncRequest + LastResponse *atomic.Value + DoNotSendCids *cid.Set + PauseMessages <-chan struct{} + Traverser ipldutil.Traverser + P peer.ID + InProgressErr chan error + Empty bool + InitialRequest bool } -func (re *requestExecutor) visitor(tp traversal.Progress, node ipld.Node, tr traversal.VisitReason) error { - select { - case <-re.ctx.Done(): - case re.inProgressChan <- graphsync.ResponseProgress{ - Node: node, - Path: tp.Path, - LastBlock: tp.LastBlock, - }: +func (e *Executor) traverse(rt RequestTask) error { + onlyOnce := &onlyOnce{e, rt, false} + // for initial request, start remote right away + if rt.InitialRequest { + if err := onlyOnce.startRemoteRequest(); err != nil { + return err + } } - return nil -} - -func (re *requestExecutor) traverse() error { - traverser := ipldutil.TraversalBuilder{ - Root: cidlink.Link{Cid: re.request.Root()}, - Selector: re.request.Selector(), - Visitor: re.visitor, - Chooser: re.nodeStyleChooser, - LinkSystem: re.env.LinkSystem, - }.Start(re.ctx) - defer traverser.Shutdown(context.Background()) for { - isComplete, err := traverser.IsComplete() + // check if traversal is complete + isComplete, err := rt.Traverser.IsComplete() if isComplete { return err } - lnk, linkContext := traverser.CurrentRequest() - resultChan := re.env.Loader(re.p, re.request.ID(), lnk, linkContext) + // get current link request + lnk, linkContext := rt.Traverser.CurrentRequest() + // attempt to load + log.Debugf("will load link=%s", lnk) + resultChan := e.loader(rt.P, rt.Request.ID(), lnk, linkContext) var result types.AsyncLoadResult + // check for immediate result select { case result = <-resultChan: default: - err := re.sendRestartAsNeeded() - if err != nil { + // if no immediate result + // initiate remote request if not already sent (we want to fill out the doNotSendCids on a resume) + if err := onlyOnce.startRemoteRequest(); err != nil { return err } + // wait for block result select { - case <-re.ctx.Done(): + case <-rt.Ctx.Done(): return ipldutil.ContextCancelError{} case result = <-resultChan: } } - err = re.processResult(traverser, lnk, result) - if _, ok := err.(hooks.ErrPaused); ok { - err = re.waitForResume() - if err != nil { - return err - } - err = traverser.Advance(bytes.NewBuffer(result.Data)) - if err != nil { - return err - } - } else if err != nil { + log.Debugf("successfully loaded link=%s, nBlocksRead=%d", lnk, rt.Traverser.NBlocksTraversed()) + // advance the traversal based on results + err = e.advanceTraversal(rt, result) + if err != nil { return err } - } -} -func (re *requestExecutor) run() { - err := re.traverse() - if err != nil { - if !isContextErr(err) { - select { - case <-re.ctx.Done(): - case re.inProgressErr <- err: - } - } - } - select { - case terminalError := <-re.terminalError: - select { - case re.inProgressErr <- terminalError: - case <-re.env.Ctx.Done(): + // check for interrupts and run block hooks + err = e.processResult(rt, lnk, result) + if err != nil { + return err } - default: } - re.terminateRequest() - close(re.inProgressChan) - close(re.inProgressErr) -} - -func (re *requestExecutor) sendRequest(request gsmsg.GraphSyncRequest) { - re.env.SendRequest(re.p, request) } -func (re *requestExecutor) terminateRequest() { - re.env.TerminateRequest(re.request.ID()) +func (e *Executor) processBlockHooks(p peer.ID, response graphsync.ResponseData, block graphsync.BlockData) error { + result := e.blockHooks.ProcessBlockHooks(p, response, 
block) + if len(result.Extensions) > 0 { + updateRequest := gsmsg.UpdateRequest(response.RequestID(), result.Extensions...) + e.manager.SendRequest(p, updateRequest) + } + return result.Err } -func (re *requestExecutor) runBlockHooks(blk graphsync.BlockData) error { - response := re.lastResponse.Load().(gsmsg.GraphSyncResponse) - return re.env.RunBlockHooks(re.p, response, blk) +func (e *Executor) onNewBlock(rt RequestTask, block graphsync.BlockData) error { + rt.DoNotSendCids.Add(block.Link().(cidlink.Link).Cid) + response := rt.LastResponse.Load().(gsmsg.GraphSyncResponse) + return e.processBlockHooks(rt.P, response, block) } -func (re *requestExecutor) waitForResume() error { - select { - case <-re.ctx.Done(): - return ipldutil.ContextCancelError{} - case re.pendingExtensions = <-re.resumeMessages: - re.restartNeeded = true - return nil +func (e *Executor) advanceTraversal(rt RequestTask, result types.AsyncLoadResult) error { + if result.Err != nil { + // before processing result check for context cancel to avoid sending an additional error + select { + case <-rt.Ctx.Done(): + return ipldutil.ContextCancelError{} + default: + } + select { + case <-rt.Ctx.Done(): + return ipldutil.ContextCancelError{} + case rt.InProgressErr <- result.Err: + rt.Traverser.Error(traversal.SkipMe{}) + return nil + } } + return rt.Traverser.Advance(bytes.NewBuffer(result.Data)) } -func (re *requestExecutor) onNewBlockWithPause(block graphsync.BlockData) error { - err := re.onNewBlock(block) +func (e *Executor) processResult(rt RequestTask, link ipld.Link, result types.AsyncLoadResult) error { + err := e.onNewBlock(rt, &blockData{link, result.Local, uint64(len(result.Data))}) select { - case <-re.pauseMessages: - re.sendRequest(gsmsg.CancelRequest(re.request.ID())) + case <-rt.PauseMessages: if err == nil { err = hooks.ErrPaused{} } @@ -203,46 +201,17 @@ func (re *requestExecutor) onNewBlockWithPause(block graphsync.BlockData) error return err } -func (re *requestExecutor) onNewBlock(block graphsync.BlockData) error { - re.doNotSendCids.Add(block.Link().(cidlink.Link).Cid) - return re.runBlockHooks(block) -} - -func (re *requestExecutor) processResult(traverser ipldutil.Traverser, link ipld.Link, result types.AsyncLoadResult) error { - if result.Err != nil { - select { - case <-re.ctx.Done(): - return ipldutil.ContextCancelError{} - case re.inProgressErr <- result.Err: - traverser.Error(traversal.SkipMe{}) - return nil +func (e *Executor) startRemoteRequest(rt RequestTask) error { + request := rt.Request + if rt.DoNotSendCids.Len() > 0 { + cidsData, err := cidset.EncodeCidSet(rt.DoNotSendCids) + if err != nil { + return err } + request = rt.Request.ReplaceExtensions([]graphsync.ExtensionData{{Name: graphsync.ExtensionDoNotSendCIDs, Data: cidsData}}) } - err := re.onNewBlockWithPause(&blockData{link, result.Local, uint64(len(result.Data))}) - if err != nil { - return err - } - err = traverser.Advance(bytes.NewBuffer(result.Data)) - if err != nil { - return err - } - return nil -} - -func (re *requestExecutor) sendRestartAsNeeded() error { - if !re.restartNeeded { - return nil - } - extensions := re.pendingExtensions - re.pendingExtensions = nil - re.restartNeeded = false - cidsData, err := cidset.EncodeCidSet(re.doNotSendCids) - if err != nil { - return err - } - extensions = append(extensions, graphsync.ExtensionData{Name: graphsync.ExtensionDoNotSendCIDs, Data: cidsData}) - re.request = re.request.ReplaceExtensions(extensions) - re.sendRequest(re.request) + log.Debugw("starting remote request", "id", 
rt.Request.ID(), "peer", rt.P.String(), "root_cid", rt.Request.Root().String()) + e.manager.SendRequest(rt.P, request) return nil } @@ -251,6 +220,25 @@ func isContextErr(err error) bool { return strings.Contains(err.Error(), ipldutil.ContextCancelError{}.Error()) } +func isPausedErr(err error) bool { + _, isPaused := err.(hooks.ErrPaused) + return isPaused +} + +type onlyOnce struct { + e *Executor + rt RequestTask + requestSent bool +} + +func (so *onlyOnce) startRemoteRequest() error { + if so.requestSent { + return nil + } + so.requestSent = true + return so.e.startRemoteRequest(so.rt) +} + type blockData struct { link ipld.Link local bool diff --git a/requestmanager/executor/executor_test.go b/requestmanager/executor/executor_test.go index 7b3155b4..4b59067f 100644 --- a/requestmanager/executor/executor_test.go +++ b/requestmanager/executor/executor_test.go @@ -9,9 +9,10 @@ import ( "time" "github.com/ipfs/go-cid" + "github.com/ipfs/go-peertaskqueue/peertask" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" - basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/ipld/go-ipld-prime/traversal" peer "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" @@ -38,219 +39,135 @@ func TestRequestExecutionBlockChain(t *testing.T) { verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { tbc.VerifyWholeChainSync(responses) require.Empty(t, receivedErrors) - require.Equal(t, 0, ree.currentWaitForResumeResult) require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.NoError(t, ree.terminalError) }, }, "error at block hook": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = errors.New("something went wrong") + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.UpdateResult{Err: errors.New("something went wrong")} }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyResponseRangeSync(responses, 0, 5) + tbc.VerifyResponseRangeSync(responses, 0, 6) require.Len(t, receivedErrors, 1) require.Regexp(t, "something went wrong", receivedErrors[0].Error()) - require.Equal(t, 0, ree.currentWaitForResumeResult) - require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) + require.Len(t, ree.requestsSent, 2) + require.Equal(t, ree.request, ree.requestsSent[0].request) + require.True(t, ree.requestsSent[1].request.IsCancel()) require.Len(t, ree.blookHooksCalled, 6) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.EqualError(t, ree.terminalError, "something went wrong") }, }, "context cancelled": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = ipldutil.ContextCancelError{} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.UpdateResult{Err: ipldutil.ContextCancelError{}} }, verifyResults: func(t *testing.T, tbc 
*testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyResponseRangeSync(responses, 0, 5) + tbc.VerifyResponseRangeSync(responses, 0, 6) require.Empty(t, receivedErrors) - require.Equal(t, 0, ree.currentWaitForResumeResult) require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) require.Len(t, ree.blookHooksCalled, 6) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.EqualError(t, ree.terminalError, ipldutil.ContextCancelError{}.Error()) }, }, "simple pause": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} - ree.waitForResumeResults = append(ree.waitForResumeResults, nil) - ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.UpdateResult{Err: hooks.ErrPaused{}} }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyWholeChainSync(responses) + tbc.VerifyResponseRangeSync(responses, 0, 6) require.Empty(t, receivedErrors) - require.Equal(t, 1, ree.currentWaitForResumeResult) + require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 6, cidSet.Len()) - require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Len(t, ree.blookHooksCalled, 6) + require.EqualError(t, ree.terminalError, hooks.ErrPaused{}.Error()) }, }, - "multiple pause": { + "preexisting do not send cids": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(7)}] = hooks.ErrPaused{} - ree.waitForResumeResults = append(ree.waitForResumeResults, nil, nil) - ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + ree.doNotSendCids.Add(tbc.GenisisLink.(cidlink.Link).Cid) }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { tbc.VerifyWholeChainSync(responses) require.Empty(t, receivedErrors) - require.Equal(t, 2, ree.currentWaitForResumeResult) - require.Equal(t, ree.request, ree.requestsSent[0].request) - doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.Equal(t, ree.request.ID(), ree.requestsSent[0].request.ID()) + require.Equal(t, ree.request.Root(), ree.requestsSent[0].request.Root()) + require.Equal(t, ree.request.Selector(), ree.requestsSent[0].request.Selector()) + doNotSendCidsExt, has := ree.requestsSent[0].request.Extension(graphsync.ExtensionDoNotSendCIDs) require.True(t, has) cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) require.NoError(t, err) - require.Equal(t, 6, cidSet.Len()) - 
doNotSendCidsExt, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 8, cidSet.Len()) + require.Equal(t, 1, cidSet.Len()) require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.NoError(t, ree.terminalError) }, }, - "multiple pause with extensions": { + "pause externally": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(7)}] = hooks.ErrPaused{} - ree.waitForResumeResults = append(ree.waitForResumeResults, []graphsync.ExtensionData{ - { - Name: graphsync.ExtensionName("applesauce"), - Data: []byte("cheese 1"), - }, - }, []graphsync.ExtensionData{ - { - Name: graphsync.ExtensionName("applesauce"), - Data: []byte("cheese 2"), - }, - }) - ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + ree.externalPause = pauseKey{requestID, tbc.LinkTipIndex(5)} }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyWholeChainSync(responses) + tbc.VerifyResponseRangeSync(responses, 0, 6) require.Empty(t, receivedErrors) - require.Equal(t, 2, ree.currentWaitForResumeResult) require.Equal(t, ree.request, ree.requestsSent[0].request) - testExtData, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionName("applesauce")) - require.True(t, has) - require.Equal(t, "cheese 1", string(testExtData)) - doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 6, cidSet.Len()) - testExtData, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionName("applesauce")) - require.True(t, has) - require.Equal(t, "cheese 2", string(testExtData)) - doNotSendCidsExt, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 8, cidSet.Len()) - require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Len(t, ree.blookHooksCalled, 6) + require.EqualError(t, ree.terminalError, hooks.ErrPaused{}.Error()) }, }, - "preexisting do not send cids": { + "resume request": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.doNotSendCids.Add(tbc.GenisisLink.(cidlink.Link).Cid) - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} - ree.waitForResumeResults = append(ree.waitForResumeResults, nil) - ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} + ree.initialRequest = false + ree.loadLocallyUntil = 6 }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { tbc.VerifyWholeChainSync(responses) require.Empty(t, receivedErrors) - require.Equal(t, 
1, ree.currentWaitForResumeResult) - require.Equal(t, ree.request, ree.requestsSent[0].request) - doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.Equal(t, ree.request.ID(), ree.requestsSent[0].request.ID()) + require.Equal(t, ree.request.Root(), ree.requestsSent[0].request.Root()) + require.Equal(t, ree.request.Selector(), ree.requestsSent[0].request.Selector()) + doNotSendCidsExt, has := ree.requestsSent[0].request.Extension(graphsync.ExtensionDoNotSendCIDs) require.True(t, has) cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) require.NoError(t, err) - require.Equal(t, 7, cidSet.Len()) + require.Equal(t, 6, cidSet.Len()) require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.NoError(t, ree.terminalError) }, }, - "pause but request is cancelled": { + "error at block hook has precedence over pause": { configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.UpdateResult{Err: errors.New("something went wrong")} + ree.externalPause = pauseKey{requestID, tbc.LinkTipIndex(5)} }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyResponseRangeSync(responses, 0, 5) - require.Empty(t, receivedErrors) - require.Equal(t, 0, ree.currentWaitForResumeResult) - require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) - require.Len(t, ree.blookHooksCalled, 6) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) - }, - }, - "pause externally": { - configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.externalPauses = append(ree.externalPauses, pauseKey{requestID, tbc.LinkTipIndex(5)}) - ree.waitForResumeResults = append(ree.waitForResumeResults, nil) - ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} - }, - verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { - tbc.VerifyWholeChainSync(responses) - require.Empty(t, receivedErrors) - require.Equal(t, 1, ree.currentPauseResult) - require.Equal(t, 1, ree.currentWaitForResumeResult) + tbc.VerifyResponseRangeSync(responses, 0, 6) + require.Len(t, receivedErrors, 1) + require.Regexp(t, "something went wrong", receivedErrors[0].Error()) + require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) require.True(t, ree.requestsSent[1].request.IsCancel()) - doNotSendCidsExt, has := ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 6, cidSet.Len()) - require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.Len(t, ree.blookHooksCalled, 6) + require.EqualError(t, ree.terminalError, "something went wrong") }, }, - "pause externally multiple": { + "sending updates": { configureRequestExecution: func(p peer.ID, requestID 
graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { - ree.externalPauses = append(ree.externalPauses, pauseKey{requestID, tbc.LinkTipIndex(5)}, pauseKey{requestID, tbc.LinkTipIndex(7)}) - ree.waitForResumeResults = append(ree.waitForResumeResults, nil, nil) - ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.UpdateResult{Extensions: []graphsync.ExtensionData{{Name: "something", Data: []byte("applesauce")}}} }, verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { tbc.VerifyWholeChainSync(responses) require.Empty(t, receivedErrors) - require.Equal(t, 2, ree.currentPauseResult) - require.Equal(t, 2, ree.currentWaitForResumeResult) + require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsCancel()) - doNotSendCidsExt, has := ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) - require.True(t, has) - cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 6, cidSet.Len()) - require.True(t, ree.requestsSent[3].request.IsCancel()) - doNotSendCidsExt, has = ree.requestsSent[4].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, ree.requestsSent[1].request.IsUpdate()) + data, has := ree.requestsSent[1].request.Extension("something") require.True(t, has) - cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) - require.NoError(t, err) - require.Equal(t, 8, cidSet.Len()) + require.Equal(t, string(data), "applesauce") require.Len(t, ree.blookHooksCalled, 10) - require.Equal(t, ree.request.ID(), ree.terminateRequested) - require.True(t, ree.nodeStyleChooserCalled) + require.NoError(t, ree.terminalError) }, }, } @@ -271,50 +188,57 @@ func TestRequestExecutionBlockChain(t *testing.T) { } } requestCtx, requestCancel := context.WithCancel(ctx) + defer requestCancel() + var responsesReceived []graphsync.ResponseProgress ree := &requestExecutionEnv{ ctx: requestCtx, - cancelFn: requestCancel, p: p, - resumeMessages: make(chan []graphsync.ExtensionData, 1), pauseMessages: make(chan struct{}, 1), - blockHookResults: make(map[blockHookKey]error), + blockHookResults: make(map[blockHookKey]hooks.UpdateResult), doNotSendCids: cid.NewSet(), request: gsmsg.NewRequest(requestID, tbc.TipLink.(cidlink.Link).Cid, tbc.Selector(), graphsync.Priority(rand.Int31())), fal: fal, tbc: tbc, configureLoader: configureLoader, + initialRequest: true, + inProgressErr: make(chan error, 1), + traverser: ipldutil.TraversalBuilder{ + Root: tbc.TipLink, + Selector: tbc.Selector(), + Visitor: func(tp traversal.Progress, node ipld.Node, tr traversal.VisitReason) error { + responsesReceived = append(responsesReceived, graphsync.ResponseProgress{ + Node: node, + Path: tp.Path, + LastBlock: tp.LastBlock, + }) + return nil + }, + }.Start(requestCtx), } fal.OnAsyncLoad(ree.checkPause) if data.configureRequestExecution != nil { data.configureRequestExecution(p, requestID, tbc, ree) } - if len(ree.loaderRanges) == 0 { - ree.loaderRanges = [][2]int{{0, 10}} - } - inProgress, inProgressErr := ree.requestExecution() - var responsesReceived []graphsync.ResponseProgress + ree.configureLoader(p, requestID, tbc, fal, [2]int{0, ree.loadLocallyUntil}) var errorsReceived []error - var inProgressDone, inProgressErrDone bool - for !inProgressDone || !inProgressErrDone { - select { - 
case response, ok := <-inProgress: - if !ok { - inProgress = nil - inProgressDone = true - } else { - responsesReceived = append(responsesReceived, response) - } - case err, ok := <-inProgressErr: - if !ok { - inProgressErr = nil - inProgressErrDone = true - } else { - errorsReceived = append(errorsReceived, err) + errCollectionErr := make(chan error, 1) + go func() { + for { + select { + case err, ok := <-ree.inProgressErr: + if !ok { + errCollectionErr <- nil + } else { + errorsReceived = append(errorsReceived, err) + } + case <-ctx.Done(): + errCollectionErr <- ctx.Err() } - case <-ctx.Done(): - t.Fatal("did not complete request") } - } + }() + executor.NewExecutor(ree, ree, fal.AsyncLoad).ExecuteTask(ctx, ree.p, &peertask.Task{}) + require.NoError(t, <-errCollectionErr) + ree.traverser.Shutdown(ctx) data.verifyResults(t, tbc, ree, responsesReceived, errorsReceived) }) } @@ -338,25 +262,22 @@ type pauseKey struct { type requestExecutionEnv struct { // params - ctx context.Context - cancelFn func() - request gsmsg.GraphSyncRequest - p peer.ID - blockHookResults map[blockHookKey]error - doNotSendCids *cid.Set - waitForResumeResults [][]graphsync.ExtensionData - resumeMessages chan []graphsync.ExtensionData - pauseMessages chan struct{} - externalPauses []pauseKey - loaderRanges [][2]int + ctx context.Context + request gsmsg.GraphSyncRequest + p peer.ID + blockHookResults map[blockHookKey]hooks.UpdateResult + doNotSendCids *cid.Set + pauseMessages chan struct{} + externalPause pauseKey + loadLocallyUntil int + traverser ipldutil.Traverser + inProgressErr chan error + initialRequest bool // results - currentPauseResult int - currentWaitForResumeResult int - requestsSent []requestSent - blookHooksCalled []blockHookKey - terminateRequested graphsync.RequestID - nodeStyleChooserCalled bool + requestsSent []requestSent + blookHooksCalled []blockHookKey + terminalError error // deps configureLoader configureLoaderFn @@ -364,79 +285,51 @@ type requestExecutionEnv struct { fal *testloader.FakeAsyncLoader } -func (ree *requestExecutionEnv) terminateRequest(requestID graphsync.RequestID) { - ree.terminateRequested = requestID +func (ree *requestExecutionEnv) ReleaseRequestTask(_ peer.ID, _ *peertask.Task, err error) { + ree.terminalError = err + close(ree.inProgressErr) } -func (ree *requestExecutionEnv) waitForResume() ([]graphsync.ExtensionData, error) { - if len(ree.waitForResumeResults) <= ree.currentWaitForResumeResult { - return nil, ipldutil.ContextCancelError{} +func (ree *requestExecutionEnv) GetRequestTask(_ peer.ID, _ *peertask.Task, requestExecutionChan chan executor.RequestTask) { + var lastResponse atomic.Value + lastResponse.Store(gsmsg.NewResponse(ree.request.ID(), graphsync.RequestAcknowledged)) + + requestExecution := executor.RequestTask{ + Ctx: ree.ctx, + Request: ree.request, + LastResponse: &lastResponse, + DoNotSendCids: ree.doNotSendCids, + PauseMessages: ree.pauseMessages, + Traverser: ree.traverser, + P: ree.p, + InProgressErr: ree.inProgressErr, + Empty: false, + InitialRequest: ree.initialRequest, } - extensions := ree.waitForResumeResults[ree.currentWaitForResumeResult] - ree.currentWaitForResumeResult++ - return extensions, nil + go func() { + select { + case <-ree.ctx.Done(): + case requestExecutionChan <- requestExecution: + } + }() } -func (ree *requestExecutionEnv) sendRequest(p peer.ID, request gsmsg.GraphSyncRequest) { +func (ree *requestExecutionEnv) SendRequest(p peer.ID, request gsmsg.GraphSyncRequest) { ree.requestsSent = append(ree.requestsSent, 
requestSent{p, request}) - if ree.currentWaitForResumeResult < len(ree.loaderRanges) && !request.IsCancel() { - ree.configureLoader(ree.p, ree.request.ID(), ree.tbc, ree.fal, ree.loaderRanges[ree.currentWaitForResumeResult]) + if !request.IsCancel() && !request.IsUpdate() { + ree.configureLoader(ree.p, ree.request.ID(), ree.tbc, ree.fal, [2]int{ree.loadLocallyUntil, len(ree.tbc.AllBlocks())}) } } -func (ree *requestExecutionEnv) nodeStyleChooser(ipld.Link, ipld.LinkContext) (ipld.NodePrototype, error) { - ree.nodeStyleChooserCalled = true - return basicnode.Prototype.Any, nil +func (ree *requestExecutionEnv) ProcessBlockHooks(p peer.ID, response graphsync.ResponseData, blk graphsync.BlockData) hooks.UpdateResult { + bhk := blockHookKey{p, response.RequestID(), blk.Link()} + ree.blookHooksCalled = append(ree.blookHooksCalled, bhk) + return ree.blockHookResults[bhk] } func (ree *requestExecutionEnv) checkPause(requestID graphsync.RequestID, link ipld.Link, result <-chan types.AsyncLoadResult) { - if ree.currentPauseResult >= len(ree.externalPauses) { - return - } - currentPause := ree.externalPauses[ree.currentPauseResult] - if currentPause.link == link && currentPause.requestID == requestID { - ree.currentPauseResult++ + if ree.externalPause.link == link && ree.externalPause.requestID == requestID { + ree.externalPause = pauseKey{} ree.pauseMessages <- struct{}{} - extensions, err := ree.waitForResume() - if err != nil { - ree.cancelFn() - } else { - ree.resumeMessages <- extensions - } } } - -func (ree *requestExecutionEnv) runBlockHooks(p peer.ID, response graphsync.ResponseData, blk graphsync.BlockData) error { - bhk := blockHookKey{p, response.RequestID(), blk.Link()} - ree.blookHooksCalled = append(ree.blookHooksCalled, bhk) - err := ree.blockHookResults[bhk] - if _, ok := err.(hooks.ErrPaused); ok { - extensions, err := ree.waitForResume() - if err != nil { - ree.cancelFn() - } else { - ree.resumeMessages <- extensions - } - } - return err -} - -func (ree *requestExecutionEnv) requestExecution() (chan graphsync.ResponseProgress, chan error) { - var lastResponse atomic.Value - lastResponse.Store(gsmsg.NewResponse(ree.request.ID(), graphsync.RequestAcknowledged)) - return executor.ExecutionEnv{ - SendRequest: ree.sendRequest, - RunBlockHooks: ree.runBlockHooks, - TerminateRequest: ree.terminateRequest, - Loader: ree.fal.AsyncLoad, - }.Start(executor.RequestExecution{ - Ctx: ree.ctx, - P: ree.p, - LastResponse: &lastResponse, - Request: ree.request, - DoNotSendCids: ree.doNotSendCids, - NodePrototypeChooser: ree.nodeStyleChooser, - ResumeMessages: ree.resumeMessages, - PauseMessages: ree.pauseMessages, - }) -} diff --git a/requestmanager/messages.go b/requestmanager/messages.go index b18efe24..812c080b 100644 --- a/requestmanager/messages.go +++ b/requestmanager/messages.go @@ -4,6 +4,8 @@ import ( blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-graphsync" gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/requestmanager/executor" + "github.com/ipfs/go-peertaskqueue/peertask" "github.com/ipld/go-ipld-prime" "github.com/libp2p/go-libp2p-core/peer" ) @@ -47,21 +49,36 @@ func (prm *processResponseMessage) handle(rm *RequestManager) { type cancelRequestMessage struct { requestID graphsync.RequestID - isPause bool onTerminated chan error terminalError error } func (crm *cancelRequestMessage) handle(rm *RequestManager) { - rm.cancelRequest(crm.requestID, crm.isPause, crm.onTerminated, crm.terminalError) + rm.cancelRequest(crm.requestID, crm.onTerminated, 
crm.terminalError) } -type terminateRequestMessage struct { - requestID graphsync.RequestID +type getRequestTaskMessage struct { + p peer.ID + task *peertask.Task + requestExecutionChan chan executor.RequestTask } -func (trm *terminateRequestMessage) handle(rm *RequestManager) { - rm.terminateRequest(trm.requestID) +func (irm *getRequestTaskMessage) handle(rm *RequestManager) { + requestExecution := rm.getRequestTask(irm.p, irm.task) + select { + case <-rm.ctx.Done(): + case irm.requestExecutionChan <- requestExecution: + } +} + +type releaseRequestTaskMessage struct { + p peer.ID + task *peertask.Task + err error +} + +func (trm *releaseRequestTaskMessage) handle(rm *RequestManager) { + rm.releaseRequestTask(trm.p, trm.task, trm.err) } type newRequestMessage struct { @@ -75,7 +92,7 @@ type newRequestMessage struct { func (nrm *newRequestMessage) handle(rm *RequestManager) { var ipr inProgressRequest - ipr.request, ipr.incoming, ipr.incomingError = rm.setupRequest(nrm.p, nrm.root, nrm.selector, nrm.extensions) + ipr.request, ipr.incoming, ipr.incomingError = rm.newRequest(nrm.p, nrm.root, nrm.selector, nrm.extensions) ipr.requestID = ipr.request.ID() select { diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index caaadad6..d015353b 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sort" "testing" "time" @@ -20,9 +21,11 @@ import ( gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/metadata" "github.com/ipfs/go-graphsync/notifications" + "github.com/ipfs/go-graphsync/requestmanager/executor" "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/requestmanager/testloader" "github.com/ipfs/go-graphsync/requestmanager/types" + "github.com/ipfs/go-graphsync/taskqueue" "github.com/ipfs/go-graphsync/testutil" ) @@ -59,6 +62,11 @@ func readNNetworkRequests(ctx context.Context, testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i)) requestRecords = append(requestRecords, rr) } + // because of the simultaneous request queues it's possible for the requests to go to the network layer out of order + // if the requests are queued at a near identical time + sort.Slice(requestRecords, func(i, j int) bool { + return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID() + }) return requestRecords } @@ -105,7 +113,9 @@ func TestNormalSimultaneousFetch(t *testing.T) { require.Equal(t, defaultPriority, requestRecords[0].gsr.Priority()) require.Equal(t, defaultPriority, requestRecords[1].gsr.Priority()) + require.Equal(t, td.blockChain.TipLink.String(), requestRecords[0].gsr.Root().String()) require.Equal(t, td.blockChain.Selector(), requestRecords[0].gsr.Selector(), "did not encode selector properly") + require.Equal(t, blockChain2.TipLink.String(), requestRecords[1].gsr.Root().String()) require.Equal(t, blockChain2.Selector(), requestRecords[1].gsr.Selector(), "did not encode selector properly") firstBlocks := append(td.blockChain.AllBlocks(), blockChain2.Blocks(0, 3)...) 
@@ -165,7 +175,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { func TestCancelRequestInProgress(t *testing.T) { ctx := context.Background() td := newTestData(ctx, t) - requestCtx, cancel := context.WithTimeout(ctx, time.Second) + requestCtx, cancel := context.WithCancel(ctx) defer cancel() requestCtx1, cancel1 := context.WithCancel(requestCtx) requestCtx2, cancel2 := context.WithCancel(requestCtx) @@ -832,7 +842,7 @@ func TestPauseResume(t *testing.T) { require.EqualError(t, err, "request is not paused") close(holdForResumeAttempt) // verify responses sent read ONLY for blocks BEFORE the pause - td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt-1) + td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt) // wait for the pause to occur <-holdForPause @@ -868,7 +878,7 @@ func TestPauseResume(t *testing.T) { td.fal.SuccessResponseOn(peers[0], rr.gsr.ID(), td.blockChain.AllBlocks()) // verify the correct results are returned, picking up after where there request was paused - td.blockChain.VerifyRemainder(ctx, returnedResponseChan, pauseAt-1) + td.blockChain.VerifyRemainder(ctx, returnedResponseChan, pauseAt) testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) } func TestPauseResumeExternal(t *testing.T) { @@ -912,7 +922,7 @@ func TestPauseResumeExternal(t *testing.T) { td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.AllBlocks()) td.fal.SuccessResponseOn(peers[0], rr.gsr.ID(), td.blockChain.AllBlocks()) // verify responses sent read ONLY for blocks BEFORE the pause - td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt-1) + td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt) // wait for the pause to occur <-holdForPause @@ -948,7 +958,7 @@ func TestPauseResumeExternal(t *testing.T) { td.fal.SuccessResponseOn(peers[0], rr.gsr.ID(), td.blockChain.AllBlocks()) // verify the correct results are returned, picking up after where there request was paused - td.blockChain.VerifyRemainder(ctx, returnedResponseChan, pauseAt-1) + td.blockChain.VerifyRemainder(ctx, returnedResponseChan, pauseAt) testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) } @@ -970,6 +980,8 @@ type testData struct { extensionData2 []byte extension2 graphsync.ExtensionData networkErrorListeners *listeners.NetworkErrorListeners + taskqueue *taskqueue.WorkerTaskQueue + executor *executor.Executor } func newTestData(ctx context.Context, t *testing.T) *testData { @@ -981,9 +993,13 @@ func newTestData(ctx context.Context, t *testing.T) *testData { td.responseHooks = hooks.NewResponseHooks() td.blockHooks = hooks.NewBlockHooks() td.networkErrorListeners = listeners.NewNetworkErrorListeners() - td.requestManager = New(ctx, td.fal, cidlink.DefaultLinkSystem(), td.requestHooks, td.responseHooks, td.blockHooks, td.networkErrorListeners) + td.taskqueue = taskqueue.NewTaskQueue(ctx) + lsys := cidlink.DefaultLinkSystem() + td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue) + td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad) td.requestManager.SetDelegate(td.fph) td.requestManager.Startup() + td.taskqueue.Startup(6, td.executor) td.blockStore = make(map[ipld.Link][]byte) td.persistence = testutil.NewTestStore(td.blockStore) td.blockChain = testutil.SetupBlockChain(ctx, t, td.persistence, 100, 5) diff --git a/requestmanager/server.go b/requestmanager/server.go index afe372a0..ed90c510 100644 --- a/requestmanager/server.go +++ 
b/requestmanager/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "time" blocks "github.com/ipfs/go-block-format" @@ -15,8 +16,10 @@ import ( gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/requestmanager/executor" "github.com/ipfs/go-graphsync/requestmanager/hooks" + "github.com/ipfs/go-peertaskqueue/peertask" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/traversal" "github.com/ipld/go-ipld-prime/traversal/selector" "github.com/libp2p/go-libp2p-core/peer" ) @@ -45,7 +48,7 @@ func (rm *RequestManager) cleanupInProcessRequests() { } } -func (rm *RequestManager) setupRequest(p peer.ID, root ipld.Link, selector ipld.Node, extensions []graphsync.ExtensionData) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) { +func (rm *RequestManager) newRequest(p peer.ID, root ipld.Link, selector ipld.Node, extensions []graphsync.ExtensionData) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) { requestID := rm.nextRequestID rm.nextRequestID++ @@ -68,55 +71,126 @@ func (rm *RequestManager) setupRequest(p peer.ID, root ipld.Link, selector ipld. doNotSendCids = cid.NewSet() } ctx, cancel := context.WithCancel(rm.ctx) - resumeMessages := make(chan []graphsync.ExtensionData, 1) - pauseMessages := make(chan struct{}, 1) - terminalError := make(chan error, 1) requestStatus := &inProgressRequestStatus{ - ctx: ctx, startTime: time.Now(), cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, terminalError: terminalError, + ctx: ctx, + startTime: time.Now(), + cancelFn: cancel, + p: p, + pauseMessages: make(chan struct{}, 1), + doNotSendCids: doNotSendCids, + request: request, + state: queued, + nodeStyleChooser: hooksResult.CustomChooser, + inProgressChan: make(chan graphsync.ResponseProgress), + inProgressErr: make(chan error), } - lastResponse := &requestStatus.lastResponse - lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) + requestStatus.lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) rm.inProgressRequestStatuses[request.ID()] = requestStatus - incoming, incomingError := executor.ExecutionEnv{ - Ctx: rm.ctx, - SendRequest: rm.SendRequest, - TerminateRequest: rm.TerminateRequest, - RunBlockHooks: rm.ProcessBlockHooks, - Loader: rm.asyncLoader.AsyncLoad, - LinkSystem: rm.linkSystem, - }.Start( - executor.RequestExecution{ - Ctx: ctx, - P: p, - Request: request, - TerminalError: terminalError, - LastResponse: lastResponse, - DoNotSendCids: doNotSendCids, - NodePrototypeChooser: hooksResult.CustomChooser, - ResumeMessages: resumeMessages, - PauseMessages: pauseMessages, - }) - return request, incoming, incomingError + + rm.requestQueue.PushTask(p, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1}) + return request, requestStatus.inProgressChan, requestStatus.inProgressErr } -func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID) { +func (rm *RequestManager) requestTask(requestID graphsync.RequestID) executor.RequestTask { ipr, ok := rm.inProgressRequestStatuses[requestID] - if ok { - log.Infow("graphsync request complete", "request id", requestID, "peer", ipr.p, "total time", time.Since(ipr.startTime)) + if !ok { + return executor.RequestTask{Empty: true} + } + log.Infow("graphsync request processing begins", "request id", requestID, "peer", ipr.p, "total time", time.Since(ipr.startTime)) + + var initialRequest bool + if 
+	if ipr.traverser == nil {
+		initialRequest = true
+		ipr.traverser = ipldutil.TraversalBuilder{
+			Root:     cidlink.Link{Cid: ipr.request.Root()},
+			Selector: ipr.request.Selector(),
+			Visitor: func(tp traversal.Progress, node ipld.Node, tr traversal.VisitReason) error {
+				select {
+				case <-ipr.ctx.Done():
+				case ipr.inProgressChan <- graphsync.ResponseProgress{
+					Node:      node,
+					Path:      tp.Path,
+					LastBlock: tp.LastBlock,
+				}:
+				}
+				return nil
+			},
+			Chooser:    ipr.nodeStyleChooser,
+			LinkSystem: rm.linkSystem,
+		}.Start(ipr.ctx)
+	}
+
+	ipr.state = running
+	return executor.RequestTask{
+		Ctx:            ipr.ctx,
+		Request:        ipr.request,
+		LastResponse:   &ipr.lastResponse,
+		DoNotSendCids:  ipr.doNotSendCids,
+		PauseMessages:  ipr.pauseMessages,
+		Traverser:      ipr.traverser,
+		P:              ipr.p,
+		InProgressErr:  ipr.inProgressErr,
+		InitialRequest: initialRequest,
+		Empty:          false,
+	}
+}
+
+func (rm *RequestManager) getRequestTask(p peer.ID, task *peertask.Task) executor.RequestTask {
+	requestID := task.Topic.(graphsync.RequestID)
+	requestExecution := rm.requestTask(requestID)
+	if requestExecution.Empty {
+		rm.requestQueue.TaskDone(p, task)
+	}
+	return requestExecution
+}
+
+func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID, ipr *inProgressRequestStatus) {
+	if ipr.terminalError != nil {
+		select {
+		case ipr.inProgressErr <- ipr.terminalError:
+		case <-rm.ctx.Done():
+		}
 	}
 	delete(rm.inProgressRequestStatuses, requestID)
+	ipr.cancelFn()
 	rm.asyncLoader.CleanupRequest(requestID)
-	if ok {
-		for _, onTerminated := range ipr.onTerminated {
-			select {
-			case <-rm.ctx.Done():
-			case onTerminated <- nil:
-			}
+	if ipr.traverser != nil {
+		ipr.traverser.Shutdown(rm.ctx)
+	}
+	// make sure context is not closed before closing channels (could cause send
+	// on close channel otherwise)
+	select {
+	case <-rm.ctx.Done():
+		return
+	default:
+	}
+	close(ipr.inProgressChan)
+	close(ipr.inProgressErr)
+	for _, onTerminated := range ipr.onTerminated {
+		select {
+		case <-rm.ctx.Done():
+		case onTerminated <- nil:
 		}
 	}
 }

-func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, isPause bool, onTerminated chan<- error, terminalError error) {
+func (rm *RequestManager) releaseRequestTask(p peer.ID, task *peertask.Task, err error) {
+	requestID := task.Topic.(graphsync.RequestID)
+	rm.requestQueue.TaskDone(p, task)
+
+	ipr, ok := rm.inProgressRequestStatuses[requestID]
+	if !ok {
+		return
+	}
+	if _, ok := err.(hooks.ErrPaused); ok {
+		ipr.state = paused
+		return
+	}
+	log.Infow("graphsync request complete", "request id", requestID, "peer", ipr.p, "total time", time.Since(ipr.startTime))
+	rm.terminateRequest(requestID, ipr)
+}
+
+func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, onTerminated chan<- error, terminalError error) {
 	inProgressRequestStatus, ok := rm.inProgressRequestStatuses[requestID]
 	if !ok {
 		if onTerminated != nil {
@@ -131,28 +205,30 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, isPause b
 	if onTerminated != nil {
 		inProgressRequestStatus.onTerminated = append(inProgressRequestStatus.onTerminated, onTerminated)
 	}
-	if terminalError != nil {
-		select {
-		case inProgressRequestStatus.terminalError <- terminalError:
-		default:
-		}
-	}
-	rm.SendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(requestID))
-	if isPause {
-		inProgressRequestStatus.paused = true
+	rm.cancelOnError(requestID, inProgressRequestStatus, terminalError)
+}
+
+func (rm *RequestManager) cancelOnError(requestID graphsync.RequestID, ipr *inProgressRequestStatus, terminalError error) {
+	if ipr.terminalError == nil {
+		ipr.terminalError = terminalError
+	}
+	if ipr.state != running {
+		rm.terminateRequest(requestID, ipr)
 	} else {
-		inProgressRequestStatus.cancelFn()
+		ipr.cancelFn()
 	}
 }

 func (rm *RequestManager) processResponseMessage(p peer.ID, responses []gsmsg.GraphSyncResponse, blks []blocks.Block) {
+	log.Debugf("begin processing message for peer %s", p)
 	filteredResponses := rm.processExtensions(responses, p)
 	filteredResponses = rm.filterResponsesForPeer(filteredResponses, p)
 	rm.updateLastResponses(filteredResponses)
 	responseMetadata := metadataForResponses(filteredResponses)
 	rm.asyncLoader.ProcessResponse(responseMetadata, blks)
 	rm.processTerminations(filteredResponses)
+	log.Debugf("end processing message for peer %s", p)
 }

 func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
@@ -195,13 +271,8 @@ func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg
 		if !ok {
 			return false
 		}
-		responseError := graphsync.RequestFailedUnknown.AsError()
-		select {
-		case requestStatus.terminalError <- responseError:
-		default:
-		}
-		rm.SendRequest(p, gsmsg.CancelRequest(response.RequestID()))
-		requestStatus.cancelFn()
+		rm.SendRequest(requestStatus.p, gsmsg.CancelRequest(response.RequestID()))
+		rm.cancelOnError(response.RequestID(), requestStatus, result.Err)
 		return false
 	}
 	return true
@@ -211,13 +282,7 @@ func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncRespons
 	for _, response := range responses {
 		if response.Status().IsTerminal() {
 			if response.Status().IsFailure() {
-				requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
-				responseError := response.Status().AsError()
-				select {
-				case requestStatus.terminalError <- responseError:
-				default:
-				}
-				requestStatus.cancelFn()
+				rm.cancelOnError(response.RequestID(), rm.inProgressRequestStatuses[response.RequestID()], response.Status().AsError())
 			}
 			rm.asyncLoader.CompleteResponsesFor(response.RequestID())
 		}
@@ -263,19 +328,13 @@ func (rm *RequestManager) unpause(id graphsync.RequestID, extensions []graphsync
 	if !ok {
 		return graphsync.RequestNotFoundErr{}
 	}
-	if !inProgressRequestStatus.paused {
+	if inProgressRequestStatus.state != paused {
 		return errors.New("request is not paused")
 	}
-	inProgressRequestStatus.paused = false
-	select {
-	case <-inProgressRequestStatus.pauseMessages:
-		rm.SendRequest(inProgressRequestStatus.p, gsmsg.UpdateRequest(id, extensions...))
-		return nil
-	case <-rm.ctx.Done():
-		return errors.New("context cancelled")
-	case inProgressRequestStatus.resumeMessages <- extensions:
-		return nil
-	}
+	inProgressRequestStatus.state = queued
+	inProgressRequestStatus.request = inProgressRequestStatus.request.ReplaceExtensions(extensions)
+	rm.requestQueue.PushTask(inProgressRequestStatus.p, peertask.Task{Topic: id, Priority: math.MaxInt32, Work: 1})
+	return nil
 }

 func (rm *RequestManager) pause(id graphsync.RequestID) error {
@@ -283,14 +342,12 @@ func (rm *RequestManager) pause(id graphsync.RequestID) error {
 	if !ok {
 		return graphsync.RequestNotFoundErr{}
 	}
-	if inProgressRequestStatus.paused {
+	if inProgressRequestStatus.state == paused {
 		return errors.New("request is already paused")
 	}
-	inProgressRequestStatus.paused = true
 	select {
-	case <-rm.ctx.Done():
-		return errors.New("context cancelled")
 	case inProgressRequestStatus.pauseMessages <- struct{}{}:
-		return nil
+	default:
 	}
+	return nil
 }
diff --git a/taskqueue/taskqueue.go b/taskqueue/taskqueue.go
new file mode 100644
index 00000000..29dbe45c
--- /dev/null
+++ b/taskqueue/taskqueue.go
@@ -0,0 +1,94 @@
+package taskqueue
+
+import (
+	"context"
+	"time"
+
+	"github.com/ipfs/go-peertaskqueue"
+	"github.com/ipfs/go-peertaskqueue/peertask"
+	peer "github.com/libp2p/go-libp2p-core/peer"
+)
+
+const thawSpeed = time.Millisecond * 100
+
+// Executor runs a single task on the queue
+type Executor interface {
+	ExecuteTask(ctx context.Context, pid peer.ID, task *peertask.Task) bool
+}
+
+type TaskQueue interface {
+	PushTask(p peer.ID, task peertask.Task)
+	TaskDone(p peer.ID, task *peertask.Task)
+}
+
+// WorkerTaskQueue is a wrapper around peertaskqueue.PeerTaskQueue that manages running workers
+// that pop tasks and execute them
+type WorkerTaskQueue struct {
+	ctx           context.Context
+	cancelFn      func()
+	peerTaskQueue *peertaskqueue.PeerTaskQueue
+	workSignal    chan struct{}
+	ticker        *time.Ticker
+}
+
+// NewTaskQueue initializes a new queue
+func NewTaskQueue(ctx context.Context) *WorkerTaskQueue {
+	ctx, cancelFn := context.WithCancel(ctx)
+	return &WorkerTaskQueue{
+		ctx:           ctx,
+		cancelFn:      cancelFn,
+		peerTaskQueue: peertaskqueue.New(),
+		workSignal:    make(chan struct{}, 1),
+		ticker:        time.NewTicker(thawSpeed),
+	}
+}
+
+// PushTask pushes a new task on to the queue
+func (tq *WorkerTaskQueue) PushTask(p peer.ID, task peertask.Task) {
+	tq.peerTaskQueue.PushTasks(p, task)
+	select {
+	case tq.workSignal <- struct{}{}:
+	default:
+	}
+}
+
+// TaskDone marks a task as completed so further tasks can be executed
+func (tq *WorkerTaskQueue) TaskDone(p peer.ID, task *peertask.Task) {
+	tq.peerTaskQueue.TasksDone(p, task)
+}
+
+// Startup runs the given number of task workers with the given executor
+func (tq *WorkerTaskQueue) Startup(workerCount uint64, executor Executor) {
+	for i := uint64(0); i < workerCount; i++ {
+		go tq.worker(executor)
+	}
+}
+
+// Shutdown shuts down all running workers
+func (tq *WorkerTaskQueue) Shutdown() {
+	tq.cancelFn()
+}
+
+func (tq *WorkerTaskQueue) worker(executor Executor) {
+	targetWork := 1
+	for {
+		pid, tasks, _ := tq.peerTaskQueue.PopTasks(targetWork)
+		for len(tasks) == 0 {
+			select {
+			case <-tq.ctx.Done():
+				return
+			case <-tq.workSignal:
+				pid, tasks, _ = tq.peerTaskQueue.PopTasks(targetWork)
+			case <-tq.ticker.C:
+				tq.peerTaskQueue.ThawRound()
+				pid, tasks, _ = tq.peerTaskQueue.PopTasks(targetWork)
+			}
+		}
+		for _, task := range tasks {
+			terminate := executor.ExecuteTask(tq.ctx, pid, task)
+			if terminate {
+				return
+			}
+		}
+	}
+}
diff --git a/testutil/testutil.go b/testutil/testutil.go
index baf20b08..d9b0e978 100644
--- a/testutil/testutil.go
+++ b/testutil/testutil.go
@@ -125,7 +125,7 @@ func CollectResponses(ctx context.Context, t TestingT, responseChan <-chan graph
 			}
 			collectedBlocks = append(collectedBlocks, blk)
 		case <-ctx.Done():
-			t.Fatal("response channel never closed")
+			require.FailNow(t, "response channel never closed")
 		}
 	}
 }
@@ -167,8 +167,8 @@ func VerifySingleTerminalError(ctx context.Context, t TestingT, errChan <-chan e
 	var err error
 	AssertReceive(ctx, t, errChan, &err, "should receive an error")
 	select {
-	case _, ok := <-errChan:
-		require.False(t, ok, "shouldn't have sent second error but did")
+	case secondErr, ok := <-errChan:
+		require.Falsef(t, ok, "shouldn't have sent second error but sent: %s, %s", err, secondErr)
 	case <-ctx.Done():
 		t.Fatal("errors not closed")
 	}