From 432a10ec7e4a59d3c15a0db1bb02591844c64773 Mon Sep 17 00:00:00 2001 From: Rod Vagg Date: Fri, 18 Aug 2023 14:08:22 +1000 Subject: [PATCH] feat(graphsync): WithResponseProgressListener to watch traverser --- transport/graphsync/graphsync.go | 82 +++++++++++++++++++++------ transport/graphsync/graphsync_test.go | 50 ++++++++++++++-- 2 files changed, 109 insertions(+), 23 deletions(-) diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index d4a4db7..6313cdc 100644 --- a/transport/graphsync/graphsync.go +++ b/transport/graphsync/graphsync.go @@ -45,6 +45,10 @@ var outgoingBlkExtensions = []graphsync.ExtensionName{ // Option is an option for setting up the graphsync transport type Option func(*Transport) +type ResponseProgressListener func(progress graphsync.ResponseProgress) + +func noopResponseProgressListener(_ graphsync.ResponseProgress) {} + // SupportedExtensions sets what data transfer extensions are supported func SupportedExtensions(supportedExtensions []graphsync.ExtensionName) Option { return func(t *Transport) { @@ -142,8 +146,13 @@ func (t *Transport) OpenChannel( return err } + listener := ch.responseProgressListener() + if listener == nil { + listener = noopResponseProgressListener + } + // Process incoming data - go t.executeGsRequest(req) + go t.executeGsRequest(req, listener) return nil } @@ -169,9 +178,10 @@ func getDoNotSendFirstBlocksExtension(channel datatransfer.ChannelState) ([]grap // Read from the graphsync response and error channels until they are closed, // and return the last error on the error channel -func (t *Transport) consumeResponses(req *gsReq) error { +func (t *Transport) consumeResponses(req *gsReq, listener ResponseProgressListener) error { var lastError error - for range req.responseChan { + for r := range req.responseChan { + listener(r) } log.Debugf("channel %s: finished consuming graphsync response channel", req.channelID) @@ -185,7 +195,7 @@ func (t *Transport) consumeResponses(req *gsReq) error { // Read from the graphsync response and error channels until they are closed // or there is an error, then call the channel completed callback -func (t *Transport) executeGsRequest(req *gsReq) { +func (t *Transport) executeGsRequest(req *gsReq, listener ResponseProgressListener) { // Make sure to call the onComplete callback before returning defer func() { log.Infow("gs request complete for channel", "chid", req.channelID) @@ -193,11 +203,11 @@ func (t *Transport) executeGsRequest(req *gsReq) { }() // Consume the response and error channels for the graphsync request - lastError := t.consumeResponses(req) + lastError := t.consumeResponses(req, listener) // Request cancelled by client if _, ok := lastError.(graphsync.RequestClientCancelledErr); ok { - terr := xerrors.Errorf("graphsync request cancelled") + terr := fmt.Errorf("graphsync request cancelled") log.Warnf("channel %s: %s", req.channelID, terr) if err := t.events.OnRequestCancelled(req.channelID, terr); err != nil { log.Error(err) @@ -220,7 +230,7 @@ func (t *Transport) executeGsRequest(req *gsReq) { var completeErr error if lastError != nil { - completeErr = xerrors.Errorf("channel %s: graphsync request failed to complete: %w", req.channelID, lastError) + completeErr = fmt.Errorf("channel %s: graphsync request failed to complete: %w", req.channelID, lastError) } // Used by the tests to listen for when a request completes @@ -266,7 +276,7 @@ func (t *Transport) CloseChannel(ctx context.Context, chid datatransfer.ChannelI err = ch.close(ctx) if err != nil { - return xerrors.Errorf("closing channel: %w", err) + return fmt.Errorf("closing channel: %w", err) } return nil } @@ -332,7 +342,7 @@ func (t *Transport) Shutdown(ctx context.Context) error { err := eg.Wait() if err != nil { - return xerrors.Errorf("shutting down graphsync transport: %w", err) + return fmt.Errorf("shutting down graphsync transport: %w", err) } return nil } @@ -362,6 +372,19 @@ func MaxLinks(maxLinks uint64) datatransfer.TransportOption { } } +// WithResponseProgressListener registers a listener for graphsync response +// progress events. Currently only one listener per channel is supported. +func WithResponseProgressListener(listener ResponseProgressListener) datatransfer.TransportOption { + return func(channelID datatransfer.ChannelID, transport datatransfer.Transport) error { + gsTransport, ok := transport.(*Transport) + if !ok { + return datatransfer.ErrUnsupported + } + gsTransport.WithResponseProgressListener(channelID, listener) + return nil + } +} + // UseStore tells the graphsync transport to use the given loader and storer for this channelID func (t *Transport) UseStore(channelID datatransfer.ChannelID, lsys ipld.LinkSystem) error { ch := t.trackDTChannel(channelID) @@ -374,6 +397,14 @@ func (t *Transport) MaxLinks(channelID datatransfer.ChannelID, maxLinks uint64) ch.setMaxLinks(maxLinks) } +// WithResponseProgressListener registers a listener for graphsync response +// progress events for this channel ID. Currently only one listener per +// channel is supported. +func (t *Transport) WithResponseProgressListener(channelID datatransfer.ChannelID, listener ResponseProgressListener) { + ch := t.trackDTChannel(channelID) + ch.setResponseProgressListener(listener) +} + // ChannelGraphsyncRequests describes any graphsync request IDs associated with a given channel type ChannelGraphsyncRequests struct { // Current is the current request ID for the transfer @@ -693,7 +724,7 @@ func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.Req var completeErr error if status != graphsync.RequestCompletedFull { statusStr := gsResponseStatusCodeString(status) - completeErr = xerrors.Errorf("graphsync response to peer %s did not complete: response status code %s", p, statusStr) + completeErr = fmt.Errorf("graphsync response to peer %s did not complete: response status code %s", p, statusStr) } // Used by the tests to listen for when a response completes @@ -903,7 +934,7 @@ func (t *Transport) getDTChannel(chid datatransfer.ChannelID) (*dtChannel, error ch, ok := t.dtChannels[chid] if !ok { - return nil, xerrors.Errorf("channel %s: %w", chid, datatransfer.ErrChannelNotFound) + return nil, fmt.Errorf("channel %s: %w", chid, datatransfer.ErrChannelNotFound) } return ch, nil } @@ -923,9 +954,10 @@ type dtChannel struct { opened chan graphsync.RequestID - optionsLk sync.RWMutex - storeRegistered bool - maxLinksOption uint64 + optionsLk sync.RWMutex + storeRegistered bool + maxLinksOption uint64 + progressListener ResponseProgressListener } // Info needed to monitor an ongoing graphsync request @@ -958,17 +990,17 @@ func (c *dtChannel) open( // Wait for the complete callback to be called err := waitForCompleteHook(ctx, completed) if err != nil { - return nil, xerrors.Errorf("%s: waiting for cancelled graphsync request to complete: %w", chid, err) + return nil, fmt.Errorf("%s: waiting for cancelled graphsync request to complete: %w", chid, err) } // Wait for the cancel request method to complete select { case err = <-errch: case <-ctx.Done(): - err = xerrors.Errorf("timed out waiting for graphsync request to be cancelled") + err = fmt.Errorf("timed out waiting for graphsync request to be cancelled") } if err != nil { - return nil, xerrors.Errorf("%s: restarting graphsync request: %w", chid, err) + return nil, fmt.Errorf("%s: restarting graphsync request: %w", chid, err) } } @@ -1181,6 +1213,20 @@ func (c *dtChannel) setMaxLinks(maxLinks uint64) { c.maxLinksOption = maxLinks } +func (c *dtChannel) setResponseProgressListener(listener ResponseProgressListener) { + c.optionsLk.Lock() + defer c.optionsLk.Unlock() + + c.progressListener = listener +} + +func (c *dtChannel) responseProgressListener() ResponseProgressListener { + c.optionsLk.Lock() + defer c.optionsLk.Unlock() + + return c.progressListener +} + // Use the given loader and storer to get / put blocks for the data-transfer. // Note that each data-transfer channel uses a separate blockstore. func (c *dtChannel) useStore(lsys ipld.LinkSystem) error { @@ -1253,7 +1299,7 @@ func (c *dtChannel) cancel(ctx context.Context) chan error { // Ignore "request not found" errors if err != nil && !xerrors.Is(graphsync.RequestNotFoundErr{}, err) { - errch <- xerrors.Errorf("cancelling graphsync request for channel %s: %w", c.channelID, err) + errch <- fmt.Errorf("cancelling graphsync request for channel %s: %w", c.channelID, err) } else { errch <- nil } diff --git a/transport/graphsync/graphsync_test.go b/transport/graphsync/graphsync_test.go index 6a728c9..6368a04 100644 --- a/transport/graphsync/graphsync_test.go +++ b/transport/graphsync/graphsync_test.go @@ -330,7 +330,6 @@ func TestManager(t *testing.T) { require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) }, }, - "incoming gs request with recognized dt response will record outgoing blocks": { requestConfig: gsRequestConfig{ dtIsResponse: true, @@ -512,7 +511,6 @@ func TestManager(t *testing.T) { require.True(t, events.ChannelCompletedSuccess) }, }, - "recognized incoming request will record unsuccessful request completion": { responseConfig: gsResponseConfig{ status: graphsync.RequestCompletedPartial, @@ -614,7 +612,6 @@ func TestManager(t *testing.T) { gsData.fgs.AssertNoPauseReceived(t) }, }, - "recognized incoming request can begin processing": { action: func(gsData *harness) { gsData.incomingRequestHook() @@ -626,7 +623,6 @@ func TestManager(t *testing.T) { events.TransferInitiatedChannelID) }, }, - "recognized incoming request can be resumed": { action: func(gsData *harness) { gsData.incomingRequestHook() @@ -641,7 +637,6 @@ func TestManager(t *testing.T) { gsData.fgs.AssertResumeReceived(gsData.ctx, t) }, }, - "unrecognized request cannot be resumed": { check: func(t *testing.T, events *fakeEvents, gsData *harness) { err := gsData.transport.ResumeChannel(gsData.ctx, @@ -1051,6 +1046,41 @@ func TestManager(t *testing.T) { require.Equal(t, uint64(101), gsData.incomingRequestHookActions.MaxLinksOption) }, }, + "WithResponseProgressListener can be used to receive progress events": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + gsData.transport.WithResponseProgressListener( + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + gsData.progressCollector.OnProgress, + ) + + stor, _ := gsData.outgoing.Selector() + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + requestReceived.ResponseChan <- graphsync.ResponseProgress{Path: datamodel.ParsePath("yep")} + requestReceived.ResponseChan <- graphsync.ResponseProgress{Path: datamodel.ParsePath("yep/yerp")} + requestReceived.ResponseChan <- graphsync.ResponseProgress{Path: datamodel.ParsePath("yep/yerp/yeppity!")} + close(requestReceived.ResponseChan) + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnChannelCompletedCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.True(t, events.ChannelCompletedSuccess) + + require.Equal(t, []string{"yep", "yep/yerp", "yep/yerp/yeppity!"}, gsData.progressCollector.paths) + }, + }, } ctx := context.Background() @@ -1090,6 +1120,7 @@ func TestManager(t *testing.T) { incomingRequestHookActions: &testharness.FakeIncomingRequestHookActions{}, requestUpdatedHookActions: &testharness.FakeRequestUpdatedActions{}, incomingResponseHookActions: &testharness.FakeIncomingResponseHookActions{}, + progressCollector: &progressCollector{paths: make([]string, 0)}, } require.NoError(t, transport.SetEventHandler(&data.events)) if data.action != nil { @@ -1235,6 +1266,7 @@ type harness struct { incomingRequestHookActions *testharness.FakeIncomingRequestHookActions requestUpdatedHookActions *testharness.FakeRequestUpdatedActions incomingResponseHookActions *testharness.FakeIncomingResponseHookActions + progressCollector *progressCollector } func (ha *harness) outgoingRequestHook() { @@ -1282,6 +1314,14 @@ func (ha *harness) incomingRequestProcessingListener() { ha.fgs.IncomingRequestProcessingListener(ha.other, ha.request, 0) } +type progressCollector struct { + paths []string +} + +func (pc *progressCollector) OnProgress(progress graphsync.ResponseProgress) { + pc.paths = append(pc.paths, progress.Path.String()) +} + type dtConfig struct { dtExtensionMissing bool dtIsResponse bool