From 0b556bf2f7dca10f6708ca93fa2fc92a773df3e1 Mon Sep 17 00:00:00 2001 From: Hannah Howard Date: Wed, 15 Jun 2022 20:30:18 -0700 Subject: [PATCH] Feat/refactor transport protocol update part 2 (#338) * feat(network): transport versioning and detection Support multiple transports on the libp2p protocol, via different protocol naming, and using libp2p to do protocol negotiation * Update transport/helpers/network/libp2p_impl.go Co-authored-by: Rod Vagg * Update transport/helpers/network/libp2p_impl.go Co-authored-by: Rod Vagg * fix(network): add versions check for legacy transport Co-authored-by: Rod Vagg --- impl/utils.go | 1 - itest/integration_test.go | 4 +- message.go | 30 ++- message/message1_1prime/message.go | 27 ++- message/message1_1prime/message_test.go | 22 +- message/message1_1prime/schema.ipldsch | 7 + message/message1_1prime/transfer_message.go | 5 +- message/message1_1prime/transfer_request.go | 32 ++- .../message1_1prime/transfer_request_test.go | 10 +- message/message1_1prime/transfer_response.go | 27 ++- .../message1_1prime/transfer_response_test.go | 10 +- testutil/faketransport.go | 5 + transport.go | 7 + transport/graphsync/extension/gsextension.go | 2 +- transport/graphsync/graphsync.go | 9 +- transport/helpers/network/interface.go | 18 +- transport/helpers/network/libp2p_impl.go | 228 +++++++++++------- transport/helpers/network/libp2p_impl_test.go | 8 +- 18 files changed, 305 insertions(+), 147 deletions(-) diff --git a/impl/utils.go b/impl/utils.go index 1a2b135..518b9e0 100644 --- a/impl/utils.go +++ b/impl/utils.go @@ -83,4 +83,3 @@ func (m *manager) cancelMessage(chid datatransfer.ChannelID) datatransfer.Messag } return message.CancelResponse(chid.ID) } - diff --git a/itest/integration_test.go b/itest/integration_test.go index 7be07a7..f13e69e 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -1779,7 +1779,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { r := &receiver{ messageReceived: make(chan receivedMessage), } - dtnet2.SetDelegate("graphsync", r) + dtnet2.SetDelegate(datatransfer.LegacyTransportID, []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) gsr := &fakeGraphSyncReceiver{ receivedMessages: make(chan receivedGraphSyncMessage), @@ -1857,7 +1857,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { r := &receiver{ messageReceived: make(chan receivedMessage), } - dtnet2.SetDelegate("graphsync", r) + dtnet2.SetDelegate(datatransfer.LegacyTransportID, []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) gsr := &fakeGraphSyncReceiver{ receivedMessages: make(chan receivedGraphSyncMessage), diff --git a/message.go b/message.go index 9eed49f..b44b43e 100644 --- a/message.go +++ b/message.go @@ -11,41 +11,41 @@ import ( "github.com/ipld/go-ipld-prime/datamodel" ) -type MessageVersion struct { +type Version struct { Major uint64 Minor uint64 Patch uint64 } -func (mv MessageVersion) String() string { +func (mv Version) String() string { return fmt.Sprintf("%d.%d.%d", mv.Major, mv.Minor, mv.Patch) } // MessageVersionFromString parses a string into a message version -func MessageVersionFromString(versionString string) (MessageVersion, error) { +func MessageVersionFromString(versionString string) (Version, error) { versions := strings.Split(versionString, ".") if len(versions) != 3 { - return MessageVersion{}, errors.New("not a version string") + return Version{}, errors.New("not a version string") } major, err := strconv.ParseUint(versions[0], 10, 0) if err != nil { - return MessageVersion{}, errors.New("unable to parse major version") + return Version{}, errors.New("unable to parse major version") } minor, err := strconv.ParseUint(versions[1], 10, 0) if err != nil { - return MessageVersion{}, errors.New("unable to parse major version") + return Version{}, errors.New("unable to parse major version") } patch, err := strconv.ParseUint(versions[2], 10, 0) if err != nil { - return MessageVersion{}, errors.New("unable to parse major version") + return Version{}, errors.New("unable to parse major version") } - return MessageVersion{Major: major, Minor: minor, Patch: patch}, nil + return Version{Major: major, Minor: minor, Patch: patch}, nil } var ( // DataTransfer1_2 is the identifier for the current // supported version of data-transfer - DataTransfer1_2 MessageVersion = MessageVersion{1, 2, 0} + DataTransfer1_2 Version = Version{1, 2, 0} ) // Message is a message for the data transfer protocol @@ -60,8 +60,16 @@ type Message interface { TransferID() TransferID ToNet(w io.Writer) error ToIPLD() datamodel.Node - MessageForVersion(targetProtocol MessageVersion) (newMsg Message, err error) - WrappedForTransport(transportID TransportID) Message + MessageForVersion(targetProtocol Version) (newMsg Message, err error) + Version() Version + WrappedForTransport(transportID TransportID, transportVersion Version) TransportedMessage +} + +// TransportedMessage is a message that can also report how it was transported +type TransportedMessage interface { + Message + TransportID() TransportID + TransportVersion() Version } // Request is a response message for the data transfer protocol diff --git a/message/message1_1prime/message.go b/message/message1_1prime/message.go index 7690dbe..641686b 100644 --- a/message/message1_1prime/message.go +++ b/message/message1_1prime/message.go @@ -212,15 +212,34 @@ func fromMessage(tresp *TransferMessage1_1) (datatransfer.Message, error) { return tresp.Response, nil } +func fromWrappedMessage(wtresp *WrappedTransferMessage1_1) (datatransfer.TransportedMessage, error) { + tresp := wtresp.Message + if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { + return nil, xerrors.Errorf("invalid/malformed message") + } + + if tresp.IsRequest { + return &WrappedTransferRequest1_1{ + tresp.Request, + wtresp.TransportVersion, + wtresp.TransportID, + }, nil + } + return &WrappedTransferResponse1_1{ + tresp.Response, + wtresp.TransportID, + wtresp.TransportVersion, + }, nil +} + // FromNetWrraped can read a network stream to deserialize a message + transport ID -func FromNetWrapped(r io.Reader) (datatransfer.TransportID, datatransfer.Message, error) { +func FromNetWrapped(r io.Reader) (datatransfer.TransportedMessage, error) { tm, err := bindnodeRegistry.TypeFromReader(r, &WrappedTransferMessage1_1{}, dagcbor.Decode) if err != nil { - return "", nil, err + return nil, err } wtresp := tm.(*WrappedTransferMessage1_1) - msg, err := fromMessage(&wtresp.Message) - return datatransfer.TransportID(wtresp.TransportID), msg, err + return fromWrappedMessage(wtresp) } // FromNet can read a network stream to deserialize a GraphSyncMessage diff --git a/message/message1_1prime/message_test.go b/message/message1_1prime/message_test.go index bc1cd84..33050de 100644 --- a/message/message1_1prime/message_test.go +++ b/message/message1_1prime/message_test.go @@ -510,6 +510,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { }) t.Run("round-trip with wrapping", func(t *testing.T) { transportID := datatransfer.TransportID("applesauce") + transportVersion := datatransfer.Version{Major: 1, Minor: 5, Patch: 0} baseCid := testutil.GenerateCids(1)[0] selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false @@ -519,15 +520,16 @@ func TestToNetFromNetEquivalency(t *testing.T) { voucherResult := testutil.NewTestTypedVoucher() request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) - wrequest := request.WrappedForTransport(transportID) + wrequest := request.WrappedForTransport(transportID, transportVersion) buf := new(bytes.Buffer) err = wrequest.ToNet(buf) require.NoError(t, err) require.Greater(t, buf.Len(), 0) - receivedTransportID, deserialized, err := message1_1.FromNetWrapped(buf) + deserialized, err := message1_1.FromNetWrapped(buf) require.NoError(t, err) - require.Equal(t, transportID, receivedTransportID) + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) deserializedRequest, ok := deserialized.(datatransfer.Request) require.True(t, ok) @@ -541,12 +543,13 @@ func TestToNetFromNetEquivalency(t *testing.T) { response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) require.NoError(t, err) - wresponse := response.WrappedForTransport(transportID) + wresponse := response.WrappedForTransport(transportID, transportVersion) err = wresponse.ToNet(buf) require.NoError(t, err) - receivedTransportID, deserialized, err = message1_1.FromNetWrapped(buf) + deserialized, err = message1_1.FromNetWrapped(buf) require.NoError(t, err) - require.Equal(t, transportID, receivedTransportID) + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) deserializedResponse, ok := deserialized.(datatransfer.Response) require.True(t, ok) @@ -559,12 +562,13 @@ func TestToNetFromNetEquivalency(t *testing.T) { testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) request = message1_1.CancelRequest(id) - wrequest = request.WrappedForTransport(transportID) + wrequest = request.WrappedForTransport(transportID, transportVersion) err = wrequest.ToNet(buf) require.NoError(t, err) - receivedTransportID, deserialized, err = message1_1.FromNetWrapped(buf) + deserialized, err = message1_1.FromNetWrapped(buf) require.NoError(t, err) - require.Equal(t, transportID, receivedTransportID) + require.Equal(t, transportID, deserialized.TransportID()) + require.Equal(t, transportVersion, deserialized.TransportVersion()) deserializedRequest, ok = deserialized.(datatransfer.Request) require.True(t, ok) diff --git a/message/message1_1prime/schema.ipldsch b/message/message1_1prime/schema.ipldsch index 06ed207..3268366 100644 --- a/message/message1_1prime/schema.ipldsch +++ b/message/message1_1prime/schema.ipldsch @@ -37,7 +37,14 @@ type TransferMessage1_1 struct { Response nullable TransferResponse } +type Version struct { + Major Int + Minor Int + Patch Int +} representation tuple + type WrappedTransferMessage1_1 struct { TransportID TransportID (rename "ID") + TransportVersion Version (rename "TV") Message TransferMessage1_1 (rename "Msg") } \ No newline at end of file diff --git a/message/message1_1prime/transfer_message.go b/message/message1_1prime/transfer_message.go index 4740fd8..0f28e52 100644 --- a/message/message1_1prime/transfer_message.go +++ b/message/message1_1prime/transfer_message.go @@ -59,8 +59,9 @@ func init() { } type WrappedTransferMessage1_1 struct { - TransportID string - Message TransferMessage1_1 + TransportID string + TransportVersion datatransfer.Version + Message TransferMessage1_1 } func (wtm *WrappedTransferMessage1_1) BindnodeSchema() string { diff --git a/message/message1_1prime/transfer_request.go b/message/message1_1prime/transfer_request.go index 175e599..007d651 100644 --- a/message/message1_1prime/transfer_request.go +++ b/message/message1_1prime/transfer_request.go @@ -29,7 +29,7 @@ type TransferRequest1_1 struct { RestartChannel datatransfer.ChannelID } -func (trq *TransferRequest1_1) MessageForVersion(version datatransfer.MessageVersion) (datatransfer.Message, error) { +func (trq *TransferRequest1_1) MessageForVersion(version datatransfer.Version) (datatransfer.Message, error) { switch version { case datatransfer.DataTransfer1_2: return trq, nil @@ -38,8 +38,16 @@ func (trq *TransferRequest1_1) MessageForVersion(version datatransfer.MessageVer } } -func (trq *TransferRequest1_1) WrappedForTransport(transportID datatransfer.TransportID) datatransfer.Message { - return &WrappedTransferRequest1_1{trq, string(transportID)} +func (trq *TransferRequest1_1) Version() datatransfer.Version { + return datatransfer.DataTransfer1_2 +} + +func (trq *TransferRequest1_1) WrappedForTransport(transportID datatransfer.TransportID, transportVersion datatransfer.Version) datatransfer.TransportedMessage { + return &WrappedTransferRequest1_1{ + TransferRequest1_1: trq, + transportID: string(transportID), + transportVersion: transportVersion, + } } // IsRequest always returns true in this case because this is a transfer request @@ -164,15 +172,25 @@ func (trq *TransferRequest1_1) ToNet(w io.Writer) error { // transport id type WrappedTransferRequest1_1 struct { *TransferRequest1_1 - TransportID string + transportVersion datatransfer.Version + transportID string +} + +func (trq *WrappedTransferRequest1_1) TransportID() datatransfer.TransportID { + return datatransfer.TransportID(trq.transportID) +} + +func (trq *WrappedTransferRequest1_1) TransportVersion() datatransfer.Version { + return trq.transportVersion } -func (trsp *WrappedTransferRequest1_1) toIPLD() schema.TypedNode { +func (trq *WrappedTransferRequest1_1) toIPLD() schema.TypedNode { msg := WrappedTransferMessage1_1{ - TransportID: trsp.TransportID, + TransportID: trq.transportID, + TransportVersion: trq.transportVersion, Message: TransferMessage1_1{ IsRequest: true, - Request: trsp.TransferRequest1_1, + Request: trq.TransferRequest1_1, Response: nil, }, } diff --git a/message/message1_1prime/transfer_request_test.go b/message/message1_1prime/transfer_request_test.go index 88a9894..410917b 100644 --- a/message/message1_1prime/transfer_request_test.go +++ b/message/message1_1prime/transfer_request_test.go @@ -40,14 +40,12 @@ func TestRequestMessageForVersion(t *testing.T) { require.Equal(t, selector, n) require.Equal(t, testutil.TestVoucherType, req.VoucherType()) - wrappedOut12 := out12.WrappedForTransport(datatransfer.LegacyTransportID) - require.Equal(t, &message1_1.WrappedTransferRequest1_1{ - TransferRequest1_1: request.(*message1_1.TransferRequest1_1), - TransportID: string(datatransfer.LegacyTransportID), - }, wrappedOut12) + wrappedOut12 := out12.WrappedForTransport(datatransfer.LegacyTransportID, datatransfer.LegacyTransportVersion) + require.Equal(t, datatransfer.LegacyTransportID, wrappedOut12.TransportID()) + require.Equal(t, datatransfer.LegacyTransportVersion, wrappedOut12.TransportVersion()) // random protocol should fail - _, err = request.MessageForVersion(datatransfer.MessageVersion{ + _, err = request.MessageForVersion(datatransfer.Version{ Major: rand.Uint64(), Minor: rand.Uint64(), Patch: rand.Uint64(), diff --git a/message/message1_1prime/transfer_response.go b/message/message1_1prime/transfer_response.go index 77cf36a..3e3c41f 100644 --- a/message/message1_1prime/transfer_response.go +++ b/message/message1_1prime/transfer_response.go @@ -87,7 +87,7 @@ func (trsp *TransferResponse1_1) EmptyVoucherResult() bool { return trsp.VoucherTypeIdentifier == datatransfer.EmptyTypeIdentifier } -func (trsp *TransferResponse1_1) MessageForVersion(version datatransfer.MessageVersion) (datatransfer.Message, error) { +func (trsp *TransferResponse1_1) MessageForVersion(version datatransfer.Version) (datatransfer.Message, error) { switch version { case datatransfer.DataTransfer1_2: return trsp, nil @@ -96,8 +96,16 @@ func (trsp *TransferResponse1_1) MessageForVersion(version datatransfer.MessageV } } -func (trsp *TransferResponse1_1) WrappedForTransport(transportID datatransfer.TransportID) datatransfer.Message { - return &WrappedTransferResponse1_1{trsp, string(transportID)} +func (trsp *TransferResponse1_1) Version() datatransfer.Version { + return datatransfer.DataTransfer1_2 +} + +func (trsp *TransferResponse1_1) WrappedForTransport(transportID datatransfer.TransportID, transportVersion datatransfer.Version) datatransfer.TransportedMessage { + return &WrappedTransferResponse1_1{ + TransferResponse1_1: trsp, + transportID: string(transportID), + transportVersion: transportVersion, + } } func (trsp *TransferResponse1_1) toIPLD() schema.TypedNode { msg := TransferMessage1_1{ @@ -121,12 +129,21 @@ func (trsp *TransferResponse1_1) ToNet(w io.Writer) error { // transport id type WrappedTransferResponse1_1 struct { *TransferResponse1_1 - TransportID string + transportID string + transportVersion datatransfer.Version +} + +func (trsp *WrappedTransferResponse1_1) TransportID() datatransfer.TransportID { + return datatransfer.TransportID(trsp.transportID) +} +func (trsp *WrappedTransferResponse1_1) TransportVersion() datatransfer.Version { + return trsp.transportVersion } func (trsp *WrappedTransferResponse1_1) toIPLD() schema.TypedNode { msg := WrappedTransferMessage1_1{ - TransportID: trsp.TransportID, + TransportID: trsp.transportID, + TransportVersion: trsp.transportVersion, Message: TransferMessage1_1{ IsRequest: false, Request: nil, diff --git a/message/message1_1prime/transfer_response_test.go b/message/message1_1prime/transfer_response_test.go index f29773b..71fcf70 100644 --- a/message/message1_1prime/transfer_response_test.go +++ b/message/message1_1prime/transfer_response_test.go @@ -28,14 +28,12 @@ func TestResponseMessageForVersion(t *testing.T) { require.Equal(t, testutil.TestVoucherType, resp.VoucherResultType()) require.True(t, resp.IsValidationResult()) - wrappedOut := out.WrappedForTransport(datatransfer.LegacyTransportID) - require.Equal(t, &message1_1.WrappedTransferResponse1_1{ - TransferResponse1_1: response.(*message1_1.TransferResponse1_1), - TransportID: string(datatransfer.LegacyTransportID), - }, wrappedOut) + wrappedOut := out.WrappedForTransport(datatransfer.LegacyTransportID, datatransfer.LegacyTransportVersion) + require.Equal(t, datatransfer.LegacyTransportID, wrappedOut.TransportID()) + require.Equal(t, datatransfer.LegacyTransportVersion, wrappedOut.TransportVersion()) // random protocol should fail - _, err = response.MessageForVersion(datatransfer.MessageVersion{ + _, err = response.MessageForVersion(datatransfer.Version{ Major: rand.Uint64(), Minor: rand.Uint64(), Patch: rand.Uint64(), diff --git a/testutil/faketransport.go b/testutil/faketransport.go index f5843ee..7e0d20c 100644 --- a/testutil/faketransport.go +++ b/testutil/faketransport.go @@ -59,6 +59,11 @@ func (ft *FakeTransport) ID() datatransfer.TransportID { return "fake" } +// Versions indicates what versions of this transport are supported +func (ft *FakeTransport) Versions() []datatransfer.Version { + return []datatransfer.Version{{Major: 1, Minor: 1, Patch: 0}} +} + // Capabilities tells datatransfer what kinds of capabilities this transport supports func (ft *FakeTransport) Capabilities() datatransfer.TransportCapabilities { return datatransfer.TransportCapabilities{ diff --git a/transport.go b/transport.go index ecaed68..546c257 100644 --- a/transport.go +++ b/transport.go @@ -13,6 +13,10 @@ type TransportID string // i.e. graphsync const LegacyTransportID TransportID = "graphsync" +// LegacyTransportVersion is the only transport version for the fil/data-transfer protocol -- +// i.e. graphsync 1.0.0 +var LegacyTransportVersion Version = Version{1, 0, 0} + // EventsHandler are semantic data transfer events that happen as a result of transport events type EventsHandler interface { // ChannelState queries for the current channel state @@ -104,6 +108,9 @@ type Transport interface { // ID is a unique identifier for this transport ID() TransportID + // Versions indicates what versions of this transport are supported + Versions() []Version + // Capabilities tells datatransfer what kinds of capabilities this transport supports Capabilities() TransportCapabilities // OpenChannel opens a channel on a given transport to move data back and forth. diff --git a/transport/graphsync/extension/gsextension.go b/transport/graphsync/extension/gsextension.go index f161ee1..28aeb66 100644 --- a/transport/graphsync/extension/gsextension.go +++ b/transport/graphsync/extension/gsextension.go @@ -20,7 +20,7 @@ const ( ) // ProtocolMap maps graphsync extensions to their libp2p protocols -var ProtocolMap = map[graphsync.ExtensionName]datatransfer.MessageVersion{ +var ProtocolMap = map[graphsync.ExtensionName]datatransfer.Version{ ExtensionIncomingRequest1_1: datatransfer.DataTransfer1_2, ExtensionOutgoingBlock1_1: datatransfer.DataTransfer1_2, ExtensionDataTransfer1_1: datatransfer.DataTransfer1_2, diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index 06dcd32..07fe307 100644 --- a/transport/graphsync/graphsync.go +++ b/transport/graphsync/graphsync.go @@ -23,6 +23,7 @@ import ( var log = logging.Logger("dt_graphsync") var transportID datatransfer.TransportID = "graphsync" +var supportedVersions = []datatransfer.Version{{Major: 1, Minor: 0, Patch: 0}} // When restarting a data transfer, we cancel the existing graphsync request // before opening a new one. @@ -109,6 +110,10 @@ func (t *Transport) ID() datatransfer.TransportID { return transportID } +func (t *Transport) Versions() []datatransfer.Version { + return supportedVersions +} + func (t *Transport) Capabilities() datatransfer.TransportCapabilities { return datatransfer.TransportCapabilities{ Pausable: true, @@ -142,7 +147,7 @@ func (t *Transport) RestartChannel( req datatransfer.Request) error { log.Debugf("%s: re-establishing connection to %s", channelState.ChannelID(), channelState.OtherPeer()) start := time.Now() - err := t.dtNet.ConnectWithRetry(ctx, channelState.OtherPeer()) + err := t.dtNet.ConnectWithRetry(ctx, channelState.OtherPeer(), transportID) if err != nil { return xerrors.Errorf("%s: failed to reconnect to peer %s after %s: %w", channelState.ChannelID(), channelState.OtherPeer(), time.Since(start), err) @@ -292,7 +297,7 @@ func (t *Transport) SetEventHandler(events datatransfer.EventsHandler) error { t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterNetworkErrorListener(t.gsNetworkSendErrorListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterReceiverNetworkErrorListener(t.gsNetworkReceiveErrorListener)) - t.dtNet.SetDelegate(transportID, &receiver{t}) + t.dtNet.SetDelegate(transportID, supportedVersions, &receiver{t}) return nil } diff --git a/transport/helpers/network/interface.go b/transport/helpers/network/interface.go index 2ed9030..022a1a3 100644 --- a/transport/helpers/network/interface.go +++ b/transport/helpers/network/interface.go @@ -20,11 +20,19 @@ const ( ProtocolDataTransfer1_2 protocol.ID = "/datatransfer/1.2.0" ) +// ProtocolDescription describes how you are connected to a given +// peer on a given transport, if at all +type ProtocolDescription struct { + IsLegacy bool + MessageVersion datatransfer.Version + TransportVersion datatransfer.Version +} + // MessageVersion extracts the message version from the full protocol -func MessageVersion(protocol protocol.ID) (datatransfer.MessageVersion, error) { +func MessageVersion(protocol protocol.ID) (datatransfer.Version, error) { protocolParts := strings.Split(string(protocol), "/") if len(protocolParts) == 0 { - return datatransfer.MessageVersion{}, errors.New("no protocol to parse") + return datatransfer.Version{}, errors.New("no protocol to parse") } return datatransfer.MessageVersionFromString(protocolParts[len(protocolParts)-1]) } @@ -43,7 +51,7 @@ type DataTransferNetwork interface { // SetDelegate registers the Reciver to handle messages received from the // network. - SetDelegate(datatransfer.TransportID, Receiver) + SetDelegate(datatransfer.TransportID, []datatransfer.Version, Receiver) // ConnectTo establishes a connection to the given peer ConnectTo(context.Context, peer.ID) error @@ -51,14 +59,14 @@ type DataTransferNetwork interface { // ConnectWithRetry establishes a connection to the given peer, retrying if // necessary, and opens a stream on the data-transfer protocol to verify // the peer will accept messages on the protocol - ConnectWithRetry(ctx context.Context, p peer.ID) error + ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error // ID returns the peer id of this libp2p host ID() peer.ID // Protocol returns the protocol version of the peer, connecting to // the peer if necessary - Protocol(context.Context, peer.ID) (protocol.ID, error) + Protocol(context.Context, peer.ID, datatransfer.TransportID) (ProtocolDescription, error) } // Receiver is an interface for receiving messages from the GraphSyncNetwork. diff --git a/transport/helpers/network/libp2p_impl.go b/transport/helpers/network/libp2p_impl.go index 412e934..d4fe0ed 100644 --- a/transport/helpers/network/libp2p_impl.go +++ b/transport/helpers/network/libp2p_impl.go @@ -2,8 +2,10 @@ package network import ( "context" + "errors" "fmt" "io" + "strings" "time" logging "github.com/ipfs/go-log/v2" @@ -16,7 +18,6 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" - "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/message" @@ -47,31 +48,35 @@ var defaultDataTransferProtocols = []protocol.ID{ ProtocolFilDataTransfer1_2, } +func isLegacyProtocol(protocol protocol.ID) bool { + return protocol == ProtocolFilDataTransfer1_2 +} + // Option is an option for configuring the libp2p storage market network type Option func(*libp2pDataTransferNetwork) // DataTransferProtocols OVERWRITES the default libp2p protocols we use for data transfer with the given protocols. func DataTransferProtocols(protocols []protocol.ID) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.setDataTransferProtocols(protocols) + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.setDataTransferProtocols(protocols) } } // SendMessageParameters changes the default parameters around sending messages func SendMessageParameters(openStreamTimeout time.Duration, sendMessageTimeout time.Duration) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.sendMessageTimeout = sendMessageTimeout - impl.openStreamTimeout = openStreamTimeout + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.sendMessageTimeout = sendMessageTimeout + dtnet.openStreamTimeout = openStreamTimeout } } // RetryParameters changes the default parameters around connection reopening func RetryParameters(minDuration time.Duration, maxDuration time.Duration, attempts float64, backoffFactor float64) Option { - return func(impl *libp2pDataTransferNetwork) { - impl.maxStreamOpenAttempts = attempts - impl.minAttemptDuration = minDuration - impl.maxAttemptDuration = maxDuration - impl.backoffFactor = backoffFactor + return func(dtnet *libp2pDataTransferNetwork) { + dtnet.maxStreamOpenAttempts = attempts + dtnet.minAttemptDuration = minDuration + dtnet.maxAttemptDuration = maxDuration + dtnet.backoffFactor = backoffFactor } } @@ -86,7 +91,8 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { minAttemptDuration: defaultMinAttemptDuration, maxAttemptDuration: defaultMaxAttemptDuration, backoffFactor: defaultBackoffFactor, - receivers: make(map[datatransfer.TransportID]Receiver), + receivers: make(map[protocol.ID]receiverData), + transportProtocols: make(map[datatransfer.TransportID]transportProtocols), } dataTransferNetwork.setDataTransferProtocols(defaultDataTransferProtocols) @@ -97,44 +103,54 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { return &dataTransferNetwork } +type transportProtocols struct { + protocols []protocol.ID + protocolStrings []string +} + +type receiverData struct { + ProtocolDescription + transportID datatransfer.TransportID + receiver Receiver +} + // libp2pDataTransferNetwork transforms the libp2p host interface, which sends and receives // NetMessage objects, into the data transfer network interface. type libp2pDataTransferNetwork struct { host host.Host // inbound messages from the network are forwarded to the receiver - receivers map[datatransfer.TransportID]Receiver - + receivers map[protocol.ID]receiverData + transportProtocols map[datatransfer.TransportID]transportProtocols openStreamTimeout time.Duration sendMessageTimeout time.Duration maxStreamOpenAttempts float64 minAttemptDuration time.Duration maxAttemptDuration time.Duration dtProtocols []protocol.ID - dtProtocolStrings []string backoffFactor float64 } -func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { +func (dtnet *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { b := &backoff.Backoff{ - Min: impl.minAttemptDuration, - Max: impl.maxAttemptDuration, - Factor: impl.backoffFactor, + Min: dtnet.minAttemptDuration, + Max: dtnet.maxAttemptDuration, + Factor: dtnet.backoffFactor, Jitter: true, } start := time.Now() for { - tctx, cancel := context.WithTimeout(ctx, impl.openStreamTimeout) + tctx, cancel := context.WithTimeout(ctx, dtnet.openStreamTimeout) defer cancel() // will use the first among the given protocols that the remote peer supports at := time.Now() - s, err := impl.host.NewStream(tctx, id, protocols...) + s, err := dtnet.host.NewStream(tctx, id, protocols...) if err == nil { nAttempts := b.Attempt() + 1 if b.Attempt() > 0 { log.Debugf("opened stream to %s on attempt %g of %g after %s", - id, nAttempts, impl.maxStreamOpenAttempts, time.Since(start)) + id, nAttempts, dtnet.maxStreamOpenAttempts, time.Since(start)) } return s, err @@ -142,13 +158,13 @@ func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.I // b.Attempt() starts from zero nAttempts := b.Attempt() + 1 - if nAttempts >= impl.maxStreamOpenAttempts { - return nil, xerrors.Errorf("exhausted %g attempts but failed to open stream to %s, err: %w", impl.maxStreamOpenAttempts, id, err) + if nAttempts >= dtnet.maxStreamOpenAttempts { + return nil, fmt.Errorf("exhausted %g attempts but failed to open stream to %s, err: %w", dtnet.maxStreamOpenAttempts, id, err) } d := b.Duration() log.Warnf("failed to open stream to %s on attempt %g of %g after %s, waiting %s to try again, err: %s", - id, nAttempts, impl.maxStreamOpenAttempts, time.Since(at), d, err) + id, nAttempts, dtnet.maxStreamOpenAttempts, time.Since(at), d, err) select { case <-ctx.Done(): @@ -176,30 +192,36 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( )) defer span.End() - s, err := dtnet.openStream(ctx, p, dtnet.dtProtocols...) + + transportProtocols, ok := dtnet.transportProtocols[transportID] + if !ok { + return datatransfer.ErrUnsupported + } + s, err := dtnet.openStream(ctx, p, transportProtocols.protocols...) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - messageVersion, err := MessageVersion(s.Protocol()) - if err != nil { - err = xerrors.Errorf("failed to determine message version for protocol: %w", err) + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + // this shouldn't happen, but let's be careful just in case to avoid a panic + err := errors.New("no receiver set") span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - outgoing, err = outgoing.MessageForVersion(messageVersion) + outgoing, err = outgoing.MessageForVersion(receiverData.MessageVersion) if err != nil { - err = xerrors.Errorf("failed to convert message for protocol: %w", err) + err = fmt.Errorf("failed to convert message for protocol: %w", err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } - if err = dtnet.msgToStream(ctx, s, transportID, outgoing); err != nil { + if err = dtnet.msgToStream(ctx, s, outgoing, receiverData); err != nil { if err2 := s.Reset(); err2 != nil { log.Error(err) span.RecordError(err2) @@ -214,9 +236,55 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( return s.Close() } -func (dtnet *libp2pDataTransferNetwork) SetDelegate(transportID datatransfer.TransportID, r Receiver) { - dtnet.receivers[transportID] = r - for _, p := range dtnet.dtProtocols { +func (dtnet *libp2pDataTransferNetwork) SetDelegate(transportID datatransfer.TransportID, versions []datatransfer.Version, r Receiver) { + transportProtocols := transportProtocols{} + for _, dtProtocol := range dtnet.dtProtocols { + messageVersion, _ := MessageVersion(dtProtocol) + if isLegacyProtocol(dtProtocol) { + if transportID == datatransfer.LegacyTransportID { + supportsLegacyVersion := false + for _, version := range versions { + if version == datatransfer.LegacyTransportVersion { + supportsLegacyVersion = true + break + } + } + if !supportsLegacyVersion { + continue + } + dtnet.receivers[dtProtocol] = receiverData{ + ProtocolDescription: ProtocolDescription{ + IsLegacy: true, + TransportVersion: datatransfer.LegacyTransportVersion, + MessageVersion: messageVersion, + }, + transportID: transportID, + receiver: r, + } + transportProtocols.protocols = append(transportProtocols.protocols, dtProtocol) + transportProtocols.protocolStrings = append(transportProtocols.protocolStrings, string(dtProtocol)) + } + } else { + for _, version := range versions { + joinedProtocol := strings.Join([]string{string(dtProtocol), string(transportID), version.String()}, "/") + dtnet.receivers[protocol.ID(joinedProtocol)] = receiverData{ + ProtocolDescription: ProtocolDescription{ + IsLegacy: false, + TransportVersion: version, + MessageVersion: messageVersion, + }, + transportID: transportID, + receiver: r, + } + transportProtocols.protocols = append(transportProtocols.protocols, protocol.ID(joinedProtocol)) + transportProtocols.protocolStrings = append(transportProtocols.protocolStrings, joinedProtocol) + } + } + } + + dtnet.transportProtocols[transportID] = transportProtocols + + for _, p := range transportProtocols.protocols { dtnet.host.SetStreamHandler(p, dtnet.handleNewStream) } } @@ -228,10 +296,14 @@ func (dtnet *libp2pDataTransferNetwork) ConnectTo(ctx context.Context, p peer.ID // ConnectWithRetry establishes a connection to the given peer, retrying if // necessary, and opens a stream on the data-transfer protocol to verify // the peer will accept messages on the protocol -func (dtnet *libp2pDataTransferNetwork) ConnectWithRetry(ctx context.Context, p peer.ID) error { +func (dtnet *libp2pDataTransferNetwork) ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error { + transportProtocols, ok := dtnet.transportProtocols[transportID] + if !ok { + return datatransfer.ErrUnsupported + } // Open a stream over the data-transfer protocol, to make sure that the // peer is listening on the protocol - s, err := dtnet.openStream(ctx, p, dtnet.dtProtocols...) + s, err := dtnet.openStream(ctx, p, transportProtocols.protocols...) if err != nil { return err } @@ -250,21 +322,20 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { return } + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + s.Reset() // nolint: errcheck,gosec + return + } p := s.Conn().RemotePeer() + // if we have no transport handler, reset the stream for { - var transportID datatransfer.TransportID var received datatransfer.Message var err error - switch s.Protocol() { - case ProtocolFilDataTransfer1_2: - if dtnet.receivers[datatransfer.LegacyTransportID] == nil { - s.Reset() // nolint: errcheck,gosec - return - } - transportID = datatransfer.LegacyTransportID + if receiverData.IsLegacy { received, err = message.FromNet(s) - case ProtocolDataTransfer1_2: - transportID, received, err = message.FromNetWrapped(s) + } else { + received, err = message.FromNetWrapped(s) } if err != nil { @@ -275,12 +346,6 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { return } - // if we have no transport handler, reset the stream - if dtnet.receivers[transportID] == nil { - s.Reset() // nolint: errcheck,gosec - return - } - ctx := context.Background() log.Debugf("net handleNewStream from %s", p) @@ -288,15 +353,15 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { receivedRequest, ok := received.(datatransfer.Request) if ok { if receivedRequest.IsRestartExistingChannelRequest() { - dtnet.receivers[transportID].ReceiveRestartExistingChannelRequest(ctx, p, receivedRequest) + receiverData.receiver.ReceiveRestartExistingChannelRequest(ctx, p, receivedRequest) } else { - dtnet.receivers[transportID].ReceiveRequest(ctx, p, receivedRequest) + receiverData.receiver.ReceiveRequest(ctx, p, receivedRequest) } } } else { receivedResponse, ok := received.(datatransfer.Response) if ok { - dtnet.receivers[transportID].ReceiveResponse(ctx, p, receivedResponse) + receiverData.receiver.ReceiveResponse(ctx, p, receivedResponse) } } } @@ -314,7 +379,7 @@ func (dtnet *libp2pDataTransferNetwork) Unprotect(id peer.ID, tag string) bool { return dtnet.host.ConnManager().Unprotect(id, tag) } -func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s network.Stream, transportID datatransfer.TransportID, msg datatransfer.Message) error { +func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message, receiverData receiverData) error { if msg.IsRequest() { log.Debugf("Outgoing request message for transfer ID: %d", msg.TransferID()) } @@ -332,15 +397,8 @@ func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s netwo } }() - switch s.Protocol() { - case ProtocolFilDataTransfer1_2: - if transportID != datatransfer.LegacyTransportID { - return fmt.Errorf("cannot send messages for transports other than graphsync on legacy protocol") - } - case ProtocolDataTransfer1_2: - msg = msg.WrappedForTransport(transportID) - default: - return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) + if !receiverData.IsLegacy { + msg = msg.WrappedForTransport(receiverData.transportID, receiverData.TransportVersion) } if err := msg.ToNet(s); err != nil { @@ -350,35 +408,41 @@ func (dtnet *libp2pDataTransferNetwork) msgToStream(ctx context.Context, s netwo return nil } -func (impl *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID) (protocol.ID, error) { +func (dtnet *libp2pDataTransferNetwork) Protocol(ctx context.Context, id peer.ID, transportID datatransfer.TransportID) (ProtocolDescription, error) { + transportProtocols, ok := dtnet.transportProtocols[transportID] + if !ok { + return ProtocolDescription{}, datatransfer.ErrUnsupported + } + // Check the cache for the peer's protocol version - firstProto, err := impl.host.Peerstore().FirstSupportedProtocol(id, impl.dtProtocolStrings...) + firstProto, err := dtnet.host.Peerstore().FirstSupportedProtocol(id, transportProtocols.protocolStrings...) if err != nil { - return "", err + return ProtocolDescription{}, err } if firstProto != "" { - return protocol.ID(firstProto), nil + receiverData, ok := dtnet.receivers[protocol.ID(firstProto)] + if !ok { + return ProtocolDescription{}, err + } + return receiverData.ProtocolDescription, nil } // The peer's protocol version is not in the cache, so connect to the peer. // Note that when the stream is opened, the peer's protocol will be added // to the cache. - s, err := impl.openStream(ctx, id, impl.dtProtocols...) + s, err := dtnet.openStream(ctx, id, dtnet.dtProtocols...) if err != nil { - return "", err + return ProtocolDescription{}, err } _ = s.Close() - - return s.Protocol(), nil + receiverData, ok := dtnet.receivers[s.Protocol()] + if !ok { + return ProtocolDescription{}, err + } + return receiverData.ProtocolDescription, nil } -func (impl *libp2pDataTransferNetwork) setDataTransferProtocols(protocols []protocol.ID) { - impl.dtProtocols = append([]protocol.ID{}, protocols...) - - // Keep a string version of the protocols for performance reasons - impl.dtProtocolStrings = make([]string, 0, len(impl.dtProtocols)) - for _, proto := range impl.dtProtocols { - impl.dtProtocolStrings = append(impl.dtProtocolStrings, string(proto)) - } +func (dtnet *libp2pDataTransferNetwork) setDataTransferProtocols(protocols []protocol.ID) { + dtnet.dtProtocols = append([]protocol.ID{}, protocols...) } diff --git a/transport/helpers/network/libp2p_impl_test.go b/transport/helpers/network/libp2p_impl_test.go index ceffc8a..76b59ef 100644 --- a/transport/helpers/network/libp2p_impl_test.go +++ b/transport/helpers/network/libp2p_impl_test.go @@ -91,8 +91,8 @@ func TestMessageSendAndReceive(t *testing.T) { messageReceived: make(chan struct{}), connectedPeers: make(chan peer.ID, 2), } - dtnet1.SetDelegate("graphsync", r) - dtnet2.SetDelegate("graphsync", r) + dtnet1.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) + dtnet2.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) err = dtnet1.ConnectTo(ctx, host2.ID()) require.NoError(t, err) @@ -263,8 +263,8 @@ func TestSendMessageRetry(t *testing.T) { messageReceived: make(chan struct{}), connectedPeers: make(chan peer.ID, 2), } - dtnet1.SetDelegate("graphsync", r) - dtnet2.SetDelegate("graphsync", r) + dtnet1.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) + dtnet2.SetDelegate("graphsync", []datatransfer.Version{datatransfer.LegacyTransportVersion}, r) err = dtnet1.ConnectTo(ctx, host2.ID()) require.NoError(t, err)