diff --git a/graphsync.go b/graphsync.go index 0c582bfc..06d34c10 100644 --- a/graphsync.go +++ b/graphsync.go @@ -2,6 +2,7 @@ package graphsync import ( "context" + "errors" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" @@ -83,6 +84,11 @@ const ( RequestFailedContentNotFound = ResponseStatusCode(34) ) +var ( + // ErrExtensionAlreadyRegistered means a user extension can be registered only once + ErrExtensionAlreadyRegistered = errors.New("extension already registered") +) + // ResponseProgress is the fundamental unit of responses making progress in Graphsync. type ResponseProgress struct { Node ipld.Node // a node which matched the graphsync query @@ -115,6 +121,19 @@ type RequestData interface { IsCancel() bool } +// ResponseData describes a received Graphsync response +type ResponseData interface { + // RequestID returns the request ID for this response + RequestID() RequestID + + // Status returns the status for a response + Status() ResponseStatusCode + + // Extension returns the content for an extension on a response, or errors + // if extension is not present + Extension(name ExtensionName) ([]byte, bool) +} + // RequestReceivedHookActions are actions that a request hook can take to change // behavior for the response type RequestReceivedHookActions interface { @@ -130,6 +149,11 @@ type RequestReceivedHookActions interface { // err - error - if not nil, halt request and return RequestRejected with the responseData type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions RequestReceivedHookActions) +// OnResponseReceivedHook is a hook that runs each time a response is received. +// It receives the peer that sent the response and all data about the response. +// If it returns an error processing is halted and the original request is cancelled. +type OnResponseReceivedHook func(p peer.ID, responseData ResponseData) error + // GraphExchange is a protocol that can exchange IPLD graphs based on a selector type GraphExchange interface { // Request initiates a new GraphSync request to the given peer using the given selector spec. @@ -140,4 +164,7 @@ type GraphExchange interface { // it is considered to have "validated" the request -- and that validation supersedes // the normal validation of requests Graphsync does (i.e. all selectors can be accepted) RegisterRequestReceivedHook(hook OnRequestReceivedHook) error + + // RegisterResponseReceivedHook adds a hook that runs when a response is received + RegisterResponseReceivedHook(OnResponseReceivedHook) error } diff --git a/impl/graphsync.go b/impl/graphsync.go index 44d0db5e..d7c36cd9 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -92,7 +92,12 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel // the normal validation of requests Graphsync does (i.e. all selectors can be accepted) func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error { gs.responseManager.RegisterHook(hook) - // may be a need to return errors here in the future... + return nil +} + +// RegisterResponseReceivedHook adds a hook that runs when a response is received +func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) error { + gs.requestManager.RegisterHook(hook) return nil } diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index e193c722..b1038772 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -2,6 +2,7 @@ package graphsync import ( "context" + "errors" "math" "math/rand" "reflect" @@ -25,173 +26,32 @@ import ( ipld "github.com/ipld/go-ipld-prime" ipldselector "github.com/ipld/go-ipld-prime/traversal/selector" "github.com/ipld/go-ipld-prime/traversal/selector/builder" + "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" mh "github.com/multiformats/go-multihash" ) -type receivedMessage struct { - message gsmsg.GraphSyncMessage - sender peer.ID -} - -// Receiver is an interface for receiving messages from the GraphSyncNetwork. -type receiver struct { - messageReceived chan receivedMessage -} - -func (r *receiver) ReceiveMessage( - ctx context.Context, - sender peer.ID, - incoming gsmsg.GraphSyncMessage) { - - select { - case <-ctx.Done(): - case r.messageReceived <- receivedMessage{incoming, sender}: - } -} - -func (r *receiver) ReceiveError(err error) { -} - -func (r *receiver) Connected(p peer.ID) { -} - -func (r *receiver) Disconnected(p peer.ID) { -} - -type blockChain struct { - genisisNode ipld.Node - genisisLink ipld.Link - middleNodes []ipld.Node - middleLinks []ipld.Link - tipNode ipld.Node - tipLink ipld.Link -} - -func createBlock(nb ipldbridge.NodeBuilder, parents []ipld.Link, size int64) ipld.Node { - return nb.CreateMap(func(mb ipldbridge.MapBuilder, knb ipldbridge.NodeBuilder, vnb ipldbridge.NodeBuilder) { - mb.Insert(knb.CreateString("Parents"), vnb.CreateList(func(lb ipldbridge.ListBuilder, vnb ipldbridge.NodeBuilder) { - for _, parent := range parents { - lb.Append(vnb.CreateLink(parent)) - } - })) - mb.Insert(knb.CreateString("Messages"), vnb.CreateList(func(lb ipldbridge.ListBuilder, vnb ipldbridge.NodeBuilder) { - lb.Append(vnb.CreateBytes(testutil.RandomBytes(size))) - })) - }) -} - -func setupBlockChain( - ctx context.Context, - t *testing.T, - storer ipldbridge.Storer, - bridge ipldbridge.IPLDBridge, - size int64, - blockChainLength int) *blockChain { - linkBuilder := cidlink.LinkBuilder{Prefix: cid.NewPrefixV1(cid.DagCBOR, mh.SHA2_256)} - var genisisNode ipld.Node - err := fluent.Recover(func() { - nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) - genisisNode = createBlock(nb, []ipld.Link{}, size) - }) - if err != nil { - t.Fatal("Error creating genesis block") - } - genesisLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, genisisNode, storer) - if err != nil { - t.Fatal("Error creating link to genesis block") - } - parent := genesisLink - middleNodes := make([]ipld.Node, 0, blockChainLength-2) - middleLinks := make([]ipld.Link, 0, blockChainLength-2) - for i := 0; i < blockChainLength-2; i++ { - var node ipld.Node - err := fluent.Recover(func() { - nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) - node = createBlock(nb, []ipld.Link{parent}, size) - }) - if err != nil { - t.Fatal("Error creating middle block") - } - middleNodes = append(middleNodes, node) - link, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, node, storer) - if err != nil { - t.Fatal("Error creating link to middle block") - } - middleLinks = append(middleLinks, link) - parent = link - } - var tipNode ipld.Node - err = fluent.Recover(func() { - nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) - tipNode = createBlock(nb, []ipld.Link{parent}, size) - }) - if err != nil { - t.Fatal("Error creating tip block") - } - tipLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, tipNode, storer) - if err != nil { - t.Fatal("Error creating link to tip block") - } - return &blockChain{genisisNode, genesisLink, middleNodes, middleLinks, tipNode, tipLink} -} - func TestMakeRequestToNetwork(t *testing.T) { // create network ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - mn := mocknet.New(ctx) - - // setup network - host1, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - host2, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - err = mn.LinkAll() - if err != nil { - t.Fatal("error linking hosts") - } - - gsnet1 := gsnet.NewFromLibp2pHost(host1) - - // setup receiving peer to just record message coming in - gsnet2 := gsnet.NewFromLibp2pHost(host2) + td := newGsTestData(ctx, t) r := &receiver{ messageReceived: make(chan receivedMessage), } - gsnet2.SetDelegate(r) - - blockStore := make(map[ipld.Link][]byte) - loader, storer := testbridge.NewMockStore(blockStore) - bridge := ipldbridge.NewIPLDBridge() - graphSync := New(ctx, gsnet1, bridge, loader, storer) + td.gsnet2.SetDelegate(r) + graphSync := td.GraphSyncHost1() blockChainLength := 100 - blockChain := setupBlockChain(ctx, t, storer, bridge, 100, blockChainLength) - - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength), - ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) { - efsb.Insert("Parents", ssb.ExploreAll( - ssb.ExploreRecursiveEdge())) - })).Node() + blockChain := setupBlockChain(ctx, t, td.storer1, td.bridge, 100, blockChainLength) - extensionData := testutil.RandomBytes(100) - extensionName := graphsync.ExtensionName("AppleSauce/McGee") - extension := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionData, - } + spec := blockChainSelector(blockChainLength) requestCtx, requestCancel := context.WithCancel(ctx) defer requestCancel() - graphSync.Request(requestCtx, host2.ID(), blockChain.tipLink, spec, extension) + graphSync.Request(requestCtx, td.host2.ID(), blockChain.tipLink, spec, td.extension) var message receivedMessage select { @@ -201,7 +61,7 @@ func TestMakeRequestToNetwork(t *testing.T) { } sender := message.sender - if sender != host1.ID() { + if sender != td.host1.ID() { t.Fatal("received message from wrong node") } @@ -211,20 +71,20 @@ func TestMakeRequestToNetwork(t *testing.T) { t.Fatal("Did not add request to received message") } receivedRequest := receivedRequests[0] - receivedSpec, err := bridge.DecodeNode(receivedRequest.Selector()) + receivedSpec, err := td.bridge.DecodeNode(receivedRequest.Selector()) if err != nil { t.Fatal("unable to decode transmitted selector") } if !reflect.DeepEqual(spec, receivedSpec) { t.Fatal("did not transmit selector spec correctly") } - _, err = bridge.ParseSelector(receivedSpec) + _, err = td.bridge.ParseSelector(receivedSpec) if err != nil { t.Fatal("did not receive parsible selector on other side") } - returnedData, found := receivedRequest.Extension(extensionName) - if !found || !reflect.DeepEqual(extensionData, returnedData) { + returnedData, found := receivedRequest.Extension(td.extensionName) + if !found || !reflect.DeepEqual(td.extensionData, returnedData) { t.Fatal("Failed to encode extension") } } @@ -234,58 +94,23 @@ func TestSendResponseToIncomingRequest(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - mn := mocknet.New(ctx) - - // setup network - host1, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - host2, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - err = mn.LinkAll() - if err != nil { - t.Fatal("error linking hosts") - } - - gsnet1 := gsnet.NewFromLibp2pHost(host1) + td := newGsTestData(ctx, t) r := &receiver{ messageReceived: make(chan receivedMessage), } - gsnet1.SetDelegate(r) - - // setup receiving peer to just record message coming in - gsnet2 := gsnet.NewFromLibp2pHost(host2) - - blockStore := make(map[ipld.Link][]byte) - loader, storer := testbridge.NewMockStore(blockStore) - bridge := ipldbridge.NewIPLDBridge() - - extensionData := testutil.RandomBytes(100) - extensionName := graphsync.ExtensionName("AppleSauce/McGee") - extension := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionData, - } - extensionResponseData := testutil.RandomBytes(100) - extensionResponse := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionResponseData, - } + td.gsnet1.SetDelegate(r) var receivedRequestData []byte // initialize graphsync on second node to response to requests - gsnet := New(ctx, gsnet2, bridge, loader, storer) - err = gsnet.RegisterRequestReceivedHook( + gsnet := td.GraphSyncHost2() + err := gsnet.RegisterRequestReceivedHook( func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) { var has bool - receivedRequestData, has = requestData.Extension(extensionName) + receivedRequestData, has = requestData.Extension(td.extensionName) if !has { t.Fatal("did not have expected extension") } - hookActions.SendExtensionData(extensionResponse) + hookActions.SendExtensionData(td.extensionResponse) }, ) if err != nil { @@ -293,24 +118,20 @@ func TestSendResponseToIncomingRequest(t *testing.T) { } blockChainLength := 100 - blockChain := setupBlockChain(ctx, t, storer, bridge, 100, blockChainLength) - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength), - ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) { - efsb.Insert("Parents", ssb.ExploreAll( - ssb.ExploreRecursiveEdge())) - })).Node() + blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 100, blockChainLength) + + spec := blockChainSelector(blockChainLength) - selectorData, err := bridge.EncodeNode(spec) + selectorData, err := td.bridge.EncodeNode(spec) if err != nil { t.Fatal("could not encode selector spec") } requestID := graphsync.RequestID(rand.Int31()) message := gsmsg.New() - message.AddRequest(gsmsg.NewRequest(requestID, blockChain.tipLink.(cidlink.Link).Cid, selectorData, graphsync.Priority(math.MaxInt32), extension)) + message.AddRequest(gsmsg.NewRequest(requestID, blockChain.tipLink.(cidlink.Link).Cid, selectorData, graphsync.Priority(math.MaxInt32), td.extension)) // send request across network - gsnet1.SendMessage(ctx, host2.ID(), message) + td.gsnet1.SendMessage(ctx, td.host2.ID(), message) // read the values sent back to requestor var received gsmsg.GraphSyncMessage var receivedBlocks []blocks.Block @@ -322,14 +143,14 @@ readAllMessages: t.Fatal("did not receive complete response") case message := <-r.messageReceived: sender := message.sender - if sender != host2.ID() { + if sender != td.host2.ID() { t.Fatal("received message from wrong node") } received = message.message receivedBlocks = append(receivedBlocks, received.Blocks()...) receivedResponses := received.Responses() - receivedExtension, found := receivedResponses[0].Extension(extensionName) + receivedExtension, found := receivedResponses[0].Extension(td.extensionName) if found { receivedExtensions = append(receivedExtensions, receivedExtension) } @@ -349,7 +170,7 @@ readAllMessages: t.Fatal("Send incorrect number of blocks or there were duplicate blocks") } - if !reflect.DeepEqual(extensionData, receivedRequestData) { + if !reflect.DeepEqual(td.extensionData, receivedRequestData) { t.Fatal("did not receive correct request extension data") } @@ -357,7 +178,7 @@ readAllMessages: t.Fatal("should have sent extension responses but didn't") } - if !reflect.DeepEqual(receivedExtensions[0], extensionResponseData) { + if !reflect.DeepEqual(receivedExtensions[0], td.extensionResponseData) { t.Fatal("did not return correct extension data") } } @@ -367,52 +188,50 @@ func TestGraphsyncRoundTrip(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - mn := mocknet.New(ctx) - - // setup network - host1, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - host2, err := mn.GenPeer() - if err != nil { - t.Fatal("error generating host") - } - err = mn.LinkAll() - if err != nil { - t.Fatal("error linking hosts") - } - - gsnet1 := gsnet.NewFromLibp2pHost(host1) - - blockStore1 := make(map[ipld.Link][]byte) - loader1, storer1 := testbridge.NewMockStore(blockStore1) - bridge1 := ipldbridge.NewIPLDBridge() + td := newGsTestData(ctx, t) // initialize graphsync on first node to make requests - requestor := New(ctx, gsnet1, bridge1, loader1, storer1) + requestor := td.GraphSyncHost1() // setup receiving peer to just record message coming in - gsnet2 := gsnet.NewFromLibp2pHost(host2) - - blockStore2 := make(map[ipld.Link][]byte) - loader2, storer2 := testbridge.NewMockStore(blockStore2) - bridge2 := ipldbridge.NewIPLDBridge() - blockChainLength := 100 - blockChain := setupBlockChain(ctx, t, storer2, bridge2, 100, blockChainLength) + blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 100, blockChainLength) // initialize graphsync on second node to response to requests - New(ctx, gsnet2, bridge2, loader2, storer2) + responder := td.GraphSyncHost2() - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength), - ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) { - efsb.Insert("Parents", ssb.ExploreAll( - ssb.ExploreRecursiveEdge())) - })).Node() + var receivedResponseData []byte + var receivedRequestData []byte - progressChan, errChan := requestor.Request(ctx, host2.ID(), blockChain.tipLink, spec) + err := requestor.RegisterResponseReceivedHook( + func(p peer.ID, responseData graphsync.ResponseData) error { + data, has := responseData.Extension(td.extensionName) + if has { + receivedResponseData = data + } + return nil + }) + if err != nil { + t.Fatal("Error setting up extension") + } + + err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) { + var has bool + receivedRequestData, has = requestData.Extension(td.extensionName) + if !has { + hookActions.TerminateWithError(errors.New("Missing extension")) + } else { + hookActions.SendExtensionData(td.extensionResponse) + } + }) + + if err != nil { + t.Fatal("Error setting up extension") + } + + spec := blockChainSelector(blockChainLength) + + progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.tipLink, spec, td.extension) responses := testutil.CollectResponses(ctx, t, progressChan) errs := testutil.CollectErrors(ctx, t, errChan) @@ -423,7 +242,7 @@ func TestGraphsyncRoundTrip(t *testing.T) { if len(errs) != 0 { t.Fatal("errors during traverse") } - if len(blockStore1) != blockChainLength { + if len(td.blockStore1) != blockChainLength { t.Fatal("did not store all blocks") } @@ -442,6 +261,15 @@ func TestGraphsyncRoundTrip(t *testing.T) { expectedPath = expectedPath + "/0" } } + + // verify extension roundtrip + if !reflect.DeepEqual(receivedRequestData, td.extensionData) { + t.Fatal("did not receive correct extension request data") + } + + if !reflect.DeepEqual(receivedResponseData, td.extensionResponseData) { + t.Fatal("did not receive correct extension response data") + } } // TestRoundTripLargeBlocksSlowNetwork test verifies graphsync continues to work @@ -460,61 +288,213 @@ func TestRoundTripLargeBlocksSlowNetwork(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() - mn := mocknet.New(ctx) + td := newGsTestData(ctx, t) + td.mn.SetLinkDefaults(mocknet.LinkOptions{Latency: 100 * time.Millisecond, Bandwidth: 3000000}) + + // initialize graphsync on first node to make requests + requestor := td.GraphSyncHost1() + + // setup receiving peer to just record message coming in + blockChainLength := 40 + blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 200000, blockChainLength) + + // initialize graphsync on second node to response to requests + td.GraphSyncHost2() + + spec := blockChainSelector(blockChainLength) + progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.tipLink, spec) + + responses := testutil.CollectResponses(ctx, t, progressChan) + errs := testutil.CollectErrors(ctx, t, errChan) + + if len(responses) != blockChainLength*2 { + t.Fatal("did not traverse all nodes") + } + if len(errs) != 0 { + t.Fatal("errors during traverse") + } +} + +type gsTestData struct { + mn mocknet.Mocknet + ctx context.Context + host1 host.Host + host2 host.Host + gsnet1 gsnet.GraphSyncNetwork + gsnet2 gsnet.GraphSyncNetwork + blockStore1, blockStore2 map[ipld.Link][]byte + loader1, loader2 ipld.Loader + storer1, storer2 ipld.Storer + bridge ipldbridge.IPLDBridge + extensionData []byte + extensionName graphsync.ExtensionName + extension graphsync.ExtensionData + extensionResponseData []byte + extensionResponse graphsync.ExtensionData +} +func newGsTestData(ctx context.Context, t *testing.T) *gsTestData { + td := &gsTestData{ctx: ctx} + td.mn = mocknet.New(ctx) + var err error // setup network - host1, err := mn.GenPeer() + td.host1, err = td.mn.GenPeer() if err != nil { t.Fatal("error generating host") } - host2, err := mn.GenPeer() + td.host2, err = td.mn.GenPeer() if err != nil { t.Fatal("error generating host") } - mn.SetLinkDefaults(mocknet.LinkOptions{Latency: 100 * time.Millisecond, Bandwidth: 3000000}) - err = mn.LinkAll() + err = td.mn.LinkAll() if err != nil { t.Fatal("error linking hosts") } - gsnet1 := gsnet.NewFromLibp2pHost(host1) + td.gsnet1 = gsnet.NewFromLibp2pHost(td.host1) + td.gsnet2 = gsnet.NewFromLibp2pHost(td.host2) + td.blockStore1 = make(map[ipld.Link][]byte) + td.loader1, td.storer1 = testbridge.NewMockStore(td.blockStore1) + td.blockStore2 = make(map[ipld.Link][]byte) + td.loader2, td.storer2 = testbridge.NewMockStore(td.blockStore2) + td.bridge = ipldbridge.NewIPLDBridge() + // setup extension handlers + td.extensionData = testutil.RandomBytes(100) + td.extensionName = graphsync.ExtensionName("AppleSauce/McGee") + td.extension = graphsync.ExtensionData{ + Name: td.extensionName, + Data: td.extensionData, + } + td.extensionResponseData = testutil.RandomBytes(100) + td.extensionResponse = graphsync.ExtensionData{ + Name: td.extensionName, + Data: td.extensionResponseData, + } - blockStore1 := make(map[ipld.Link][]byte) - loader1, storer1 := testbridge.NewMockStore(blockStore1) - bridge1 := ipldbridge.NewIPLDBridge() + return td +} - // initialize graphsync on first node to make requests - requestor := New(ctx, gsnet1, bridge1, loader1, storer1) +func (td *gsTestData) GraphSyncHost1() graphsync.GraphExchange { + return New(td.ctx, td.gsnet1, td.bridge, td.loader1, td.storer1) +} - // setup receiving peer to just record message coming in - gsnet2 := gsnet.NewFromLibp2pHost(host2) +func (td *gsTestData) GraphSyncHost2() graphsync.GraphExchange { - blockStore2 := make(map[ipld.Link][]byte) - loader2, storer2 := testbridge.NewMockStore(blockStore2) - bridge2 := ipldbridge.NewIPLDBridge() + return New(td.ctx, td.gsnet2, td.bridge, td.loader2, td.storer2) +} - blockChainLength := 40 - blockChain := setupBlockChain(ctx, t, storer2, bridge2, 200000, blockChainLength) +type receivedMessage struct { + message gsmsg.GraphSyncMessage + sender peer.ID +} - // initialize graphsync on second node to response to requests - New(ctx, gsnet2, bridge2, loader2, storer2) +// Receiver is an interface for receiving messages from the GraphSyncNetwork. +type receiver struct { + messageReceived chan receivedMessage +} - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength), - ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) { - efsb.Insert("Parents", ssb.ExploreAll( - ssb.ExploreRecursiveEdge())) - })).Node() +func (r *receiver) ReceiveMessage( + ctx context.Context, + sender peer.ID, + incoming gsmsg.GraphSyncMessage) { - progressChan, errChan := requestor.Request(ctx, host2.ID(), blockChain.tipLink, spec) + select { + case <-ctx.Done(): + case r.messageReceived <- receivedMessage{incoming, sender}: + } +} - responses := testutil.CollectResponses(ctx, t, progressChan) - errs := testutil.CollectErrors(ctx, t, errChan) +func (r *receiver) ReceiveError(err error) { +} - if len(responses) != blockChainLength*2 { - t.Fatal("did not traverse all nodes") +func (r *receiver) Connected(p peer.ID) { +} + +func (r *receiver) Disconnected(p peer.ID) { +} + +type blockChain struct { + genisisNode ipld.Node + genisisLink ipld.Link + middleNodes []ipld.Node + middleLinks []ipld.Link + tipNode ipld.Node + tipLink ipld.Link +} + +func createBlock(nb ipldbridge.NodeBuilder, parents []ipld.Link, size int64) ipld.Node { + return nb.CreateMap(func(mb ipldbridge.MapBuilder, knb ipldbridge.NodeBuilder, vnb ipldbridge.NodeBuilder) { + mb.Insert(knb.CreateString("Parents"), vnb.CreateList(func(lb ipldbridge.ListBuilder, vnb ipldbridge.NodeBuilder) { + for _, parent := range parents { + lb.Append(vnb.CreateLink(parent)) + } + })) + mb.Insert(knb.CreateString("Messages"), vnb.CreateList(func(lb ipldbridge.ListBuilder, vnb ipldbridge.NodeBuilder) { + lb.Append(vnb.CreateBytes(testutil.RandomBytes(size))) + })) + }) +} + +func setupBlockChain( + ctx context.Context, + t *testing.T, + storer ipldbridge.Storer, + bridge ipldbridge.IPLDBridge, + size int64, + blockChainLength int) *blockChain { + linkBuilder := cidlink.LinkBuilder{Prefix: cid.NewPrefixV1(cid.DagCBOR, mh.SHA2_256)} + var genisisNode ipld.Node + err := fluent.Recover(func() { + nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) + genisisNode = createBlock(nb, []ipld.Link{}, size) + }) + if err != nil { + t.Fatal("Error creating genesis block") } - if len(errs) != 0 { - t.Fatal("errors during traverse") + genesisLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, genisisNode, storer) + if err != nil { + t.Fatal("Error creating link to genesis block") + } + parent := genesisLink + middleNodes := make([]ipld.Node, 0, blockChainLength-2) + middleLinks := make([]ipld.Link, 0, blockChainLength-2) + for i := 0; i < blockChainLength-2; i++ { + var node ipld.Node + err := fluent.Recover(func() { + nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) + node = createBlock(nb, []ipld.Link{parent}, size) + }) + if err != nil { + t.Fatal("Error creating middle block") + } + middleNodes = append(middleNodes, node) + link, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, node, storer) + if err != nil { + t.Fatal("Error creating link to middle block") + } + middleLinks = append(middleLinks, link) + parent = link } + var tipNode ipld.Node + err = fluent.Recover(func() { + nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder()) + tipNode = createBlock(nb, []ipld.Link{parent}, size) + }) + if err != nil { + t.Fatal("Error creating tip block") + } + tipLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, tipNode, storer) + if err != nil { + t.Fatal("Error creating link to tip block") + } + return &blockChain{genisisNode, genesisLink, middleNodes, middleLinks, tipNode, tipLink} +} + +func blockChainSelector(blockChainLength int) ipld.Node { + ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) + return ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength), + ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) { + efsb.Insert("Parents", ssb.ExploreAll( + ssb.ExploreRecursiveEdge())) + })).Node() } diff --git a/requestmanager/requestmanager.go b/requestmanager/requestmanager.go index b3dfd2b3..4590265f 100644 --- a/requestmanager/requestmanager.go +++ b/requestmanager/requestmanager.go @@ -32,6 +32,10 @@ type inProgressRequestStatus struct { networkError chan error } +type responseHook struct { + hook graphsync.OnResponseReceivedHook +} + // PeerHandler is an interface that can send requests to peers type PeerHandler interface { SendRequest(p peer.ID, graphSyncRequest gsmsg.GraphSyncRequest) @@ -61,6 +65,7 @@ type RequestManager struct { // dont touch out side of run loop nextRequestID graphsync.RequestID inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus + responseHooks []responseHook } type requestManagerMessage interface { @@ -197,6 +202,15 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn } } +// RegisterHook registers an extension to processincoming responses +func (rm *RequestManager) RegisterHook( + hook graphsync.OnResponseReceivedHook) { + select { + case rm.messages <- &responseHook{hook}: + case <-rm.ctx.Done(): + } +} + // Startup starts processing for the WantManager. func (rm *RequestManager) Startup() { go rm.run() @@ -266,11 +280,16 @@ func (crm *cancelRequestMessage) handle(rm *RequestManager) { func (prm *processResponseMessage) handle(rm *RequestManager) { filteredResponses := rm.filterResponsesForPeer(prm.responses, prm.p) + filteredResponses = rm.processExtensions(filteredResponses, prm.p) responseMetadata := metadataForResponses(filteredResponses, rm.ipldBridge) rm.asyncLoader.ProcessResponse(responseMetadata, prm.blks) rm.processTerminations(filteredResponses) } +func (rh *responseHook) handle(rm *RequestManager) { + rm.responseHooks = append(rm.responseHooks, *rh) +} + func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse { responsesForPeer := make([]gsmsg.GraphSyncResponse, 0, len(responses)) for _, response := range responses { @@ -283,6 +302,34 @@ func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResp return responsesForPeer } +func (rm *RequestManager) processExtensions(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse { + remainingResponses := make([]gsmsg.GraphSyncResponse, 0, len(responses)) + for _, response := range responses { + success := rm.processExtensionsForResponse(p, response) + if success { + remainingResponses = append(remainingResponses, response) + } + } + return remainingResponses +} + +func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg.GraphSyncResponse) bool { + for _, responseHook := range rm.responseHooks { + err := responseHook.hook(p, response) + if err != nil { + requestStatus := rm.inProgressRequestStatuses[response.RequestID()] + responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown) + select { + case requestStatus.networkError <- responseError: + case <-requestStatus.ctx.Done(): + } + requestStatus.cancelFn() + return false + } + } + return true +} + func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncResponse) { for _, response := range responses { if gsmsg.IsTerminalResponseCode(response.Status()) { diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index 48ff6f67..926cd747 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -2,6 +2,7 @@ package requestmanager import ( "context" + "errors" "fmt" "reflect" "sync" @@ -631,7 +632,19 @@ func TestEncodingExtensions(t *testing.T) { Name: extensionName2, Data: extensionData2, } - _, _ = requestManager.SendRequest(requestCtx, peers[0], root, selector, extension1, extension2) + + expectedError := make(chan error, 2) + receivedExtensionData := make(chan []byte, 2) + hook := func(p peer.ID, responseData graphsync.ResponseData) error { + data, has := responseData.Extension(extensionName1) + if !has { + t.Fatal("Did not receive extension data in response") + } + receivedExtensionData <- data + return <-expectedError + } + requestManager.RegisterHook(hook) + returnedResponseChan, returnedErrorChan := requestManager.SendRequest(requestCtx, peers[0], root, selector, extension1, extension2) rr := readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0] @@ -646,4 +659,55 @@ func TestEncodingExtensions(t *testing.T) { t.Fatal("Failed to encode first extension") } + t.Run("responding to extensions", func(t *testing.T) { + expectedData := testutil.RandomBytes(100) + firstResponses := []gsmsg.GraphSyncResponse{ + gsmsg.NewResponse(gsr.ID(), + graphsync.PartialResponse, graphsync.ExtensionData{ + Name: graphsync.ExtensionMetadata, + Data: nil, + }, + graphsync.ExtensionData{ + Name: extensionName1, + Data: expectedData, + }, + ), + } + expectedError <- nil + requestManager.ProcessResponses(peers[0], firstResponses, nil) + select { + case <-requestCtx.Done(): + t.Fatal("Should have checked extension but didn't") + case received := <-receivedExtensionData: + if !reflect.DeepEqual(received, expectedData) { + t.Fatal("Did not receive correct extension data from resposne") + } + } + nextExpectedData := testutil.RandomBytes(100) + + secondResponses := []gsmsg.GraphSyncResponse{ + gsmsg.NewResponse(gsr.ID(), + graphsync.PartialResponse, graphsync.ExtensionData{ + Name: graphsync.ExtensionMetadata, + Data: nil, + }, + graphsync.ExtensionData{ + Name: extensionName1, + Data: nextExpectedData, + }, + ), + } + expectedError <- errors.New("a terrible thing happened") + requestManager.ProcessResponses(peers[0], secondResponses, nil) + select { + case <-requestCtx.Done(): + t.Fatal("Should have checked extension but didn't") + case received := <-receivedExtensionData: + if !reflect.DeepEqual(received, nextExpectedData) { + t.Fatal("Did not receive correct extension data from resposne") + } + } + testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan) + testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) + }) }