Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make deal state channel id nilable #490

Merged
merged 3 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions retrievalmarket/impl/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/filecoin-project/go-address"
datatransfer "github.com/filecoin-project/go-data-transfer"
versioning "github.com/filecoin-project/go-ds-versioning/pkg"
versionedfsm "github.com/filecoin-project/go-ds-versioning/pkg/fsm"
"github.com/filecoin-project/go-multistore"
"github.com/filecoin-project/go-state-types/abi"
Expand Down Expand Up @@ -99,7 +98,7 @@ func NewClient(
StateEntryFuncs: clientstates.ClientStateEntryFuncs,
FinalityStates: clientstates.ClientFinalityStates,
Notifier: c.notifySubscribers,
}, retrievalMigrations, versioning.VersionKey("1"))
}, retrievalMigrations, "2")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func TestMigrations(t *testing.T) {
},
},
StoreID: storeIDs[i],
ChannelID: channelIDs[i],
ChannelID: &channelIDs[i],
LastPaymentRequested: lastPaymentRequesteds[i],
AllBlocksReceived: allBlocksReceiveds[i],
TotalFunds: totalFundss[i],
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/clientstates/client_fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var ClientEvents = fsm.Events{
From(rm.DealStatusNew).To(rm.DealStatusWaitForAcceptance).
From(rm.DealStatusRetryLegacy).To(rm.DealStatusWaitForAcceptanceLegacy).
Action(func(deal *rm.ClientDealState, channelID datatransfer.ChannelID) error {
deal.ChannelID = channelID
deal.ChannelID = &channelID
deal.Message = ""
return nil
}),
Expand Down
13 changes: 8 additions & 5 deletions retrievalmarket/impl/clientstates/client_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func SendFunds(ctx fsm.Context, environment ClientDealEnvironment, deal rm.Clien
}

// send payment voucher (or fail)
err = environment.SendDataTransferVoucher(ctx.Context(), deal.ChannelID, &rm.DealPayment{
err = environment.SendDataTransferVoucher(ctx.Context(), *deal.ChannelID, &rm.DealPayment{
ID: deal.DealProposal.ID,
PaymentChannel: deal.PaymentInfo.PayCh,
PaymentVoucher: voucher,
Expand Down Expand Up @@ -164,10 +164,13 @@ func CheckFunds(ctx fsm.Context, environment ClientDealEnvironment, deal rm.Clie

// CancelDeal clears a deal that went wrong for an unknown reason
func CancelDeal(ctx fsm.Context, environment ClientDealEnvironment, deal rm.ClientDealState) error {
// Read next response (or fail)
err := environment.CloseDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ClientEventDataTransferError, err)
// If the data transfer has started, cancel it
if deal.ChannelID != nil {
// Read next response (or fail)
err := environment.CloseDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ClientEventDataTransferError, err)
}
}

return ctx.Trigger(rm.ClientEventCancelComplete)
Expand Down
11 changes: 11 additions & 0 deletions retrievalmarket/impl/clientstates/client_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ func TestSendFunds(t *testing.T) {
node := testnodes.NewTestRetrievalClientNode(nodeParams)
environment := &fakeEnvironment{node, nil, sendDataTransferVoucherError, nil}
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Sender,
ID: 1,
}
dealState.ChannelID.Responder = dealState.Sender
dirkmc marked this conversation as resolved.
Show resolved Hide resolved
err := clientstates.SendFunds(fsmCtx, environment, *dealState)
require.NoError(t, err)
fsmCtx.ReplayEvents(t, dealState)
Expand Down Expand Up @@ -526,6 +532,11 @@ func TestCancelDeal(t *testing.T) {
node := testnodes.NewTestRetrievalClientNode(testnodes.TestRetrievalClientNodeParams{})
environment := &fakeEnvironment{node, nil, nil, closeError}
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Sender,
ID: 1,
}
err := clientstates.CancelDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
fsmCtx.ReplayEvents(t, dealState)
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ func TestProviderMigrations(t *testing.T) {
},
},
StoreID: storeIDs[i],
ChannelID: channelIDs[i],
ChannelID: &channelIDs[i],
PieceInfo: &piecestore.PieceInfo{
PieceCID: *pieceCIDs[i],
Deals: []piecestore.DealInfo{
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/providerstates/provider_fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var ProviderEvents = fsm.Events{
From(rm.DealStatusFundsNeededUnseal).ToNoChange().
From(rm.DealStatusNew).To(rm.DealStatusUnsealing).
Action(func(deal *rm.ProviderDealState, channelID datatransfer.ChannelID) error {
deal.ChannelID = channelID
deal.ChannelID = &channelID
return nil
}),

Expand Down
16 changes: 10 additions & 6 deletions retrievalmarket/impl/providerstates/provider_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ func UnpauseDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal rm.P
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
err = environment.ResumeDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
if deal.ChannelID != nil {
err = environment.ResumeDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
}
return nil
}
Expand All @@ -87,9 +89,11 @@ func CancelDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal rm.Pr
if err != nil {
return ctx.Trigger(rm.ProviderEventMultiStoreError, err)
}
err = environment.CloseDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil && !errors.Is(err, statemachine.ErrTerminated) {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
if deal.ChannelID != nil {
err = environment.CloseDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil && !errors.Is(err, statemachine.ErrTerminated) {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
}
return ctx.Trigger(rm.ProviderEventCancelComplete)
}
Expand Down
11 changes: 11 additions & 0 deletions retrievalmarket/impl/providerstates/provider_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/require"

datatransfer "github.com/filecoin-project/go-data-transfer"
"github.com/filecoin-project/go-state-types/abi"
"github.com/filecoin-project/go-state-types/big"
"github.com/filecoin-project/go-statemachine/fsm"
Expand Down Expand Up @@ -112,6 +113,11 @@ func TestUnpauseDeal(t *testing.T) {
environment := rmtesting.NewTestProviderDealEnvironment(node)
setupEnv(environment)
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Receiver,
ID: 1,
}
err := providerstates.UnpauseDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
node.VerifyExpectations(t)
Expand Down Expand Up @@ -155,6 +161,11 @@ func TestCancelDeal(t *testing.T) {
environment := rmtesting.NewTestProviderDealEnvironment(node)
setupEnv(environment)
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Receiver,
ID: 1,
}
err := providerstates.CancelDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
node.VerifyExpectations(t)
Expand Down
22 changes: 19 additions & 3 deletions retrievalmarket/impl/requestvalidation/revalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"sync"

logging "github.com/ipfs/go-log/v2"

datatransfer "github.com/filecoin-project/go-data-transfer"
"github.com/filecoin-project/go-state-types/abi"
"github.com/filecoin-project/go-state-types/big"
Expand All @@ -14,6 +16,8 @@ import (
"github.com/filecoin-project/go-fil-markets/retrievalmarket/migrations"
)

var log = logging.Logger("retrieval-revalidator")

// RevalidatorEnvironment are the dependencies needed to
// build the logic of revalidation -- essentially, access to the node at statemachines
type RevalidatorEnvironment interface {
Expand Down Expand Up @@ -52,9 +56,15 @@ func NewProviderRevalidator(env RevalidatorEnvironment) *ProviderRevalidator {
// a given channel ID with a retrieval deal, so that checks run for data sent
// on the channel
func (pr *ProviderRevalidator) TrackChannel(deal rm.ProviderDealState) {
// Sanity check
if deal.ChannelID == nil {
log.Errorf("cannot track deal %s: channel ID is nil", deal.ID)
return
}

pr.trackedChannelsLk.Lock()
defer pr.trackedChannelsLk.Unlock()
pr.trackedChannels[deal.ChannelID] = &channelData{
pr.trackedChannels[*deal.ChannelID] = &channelData{
dealID: deal.Identifier(),
}
pr.writeDealState(deal)
Expand All @@ -63,9 +73,15 @@ func (pr *ProviderRevalidator) TrackChannel(deal rm.ProviderDealState) {
// UntrackChannel indicates a retrieval deal is finish and no longer is tracked
// by this provider
func (pr *ProviderRevalidator) UntrackChannel(deal rm.ProviderDealState) {
// Sanity check
if deal.ChannelID == nil {
log.Errorf("cannot untrack deal %s: channel ID is nil", deal.ID)
return
}

pr.trackedChannelsLk.Lock()
defer pr.trackedChannelsLk.Unlock()
delete(pr.trackedChannels, deal.ChannelID)
delete(pr.trackedChannels, *deal.ChannelID)
}

func (pr *ProviderRevalidator) loadDealState(channel *channelData) error {
Expand All @@ -82,7 +98,7 @@ func (pr *ProviderRevalidator) loadDealState(channel *channelData) error {
}

func (pr *ProviderRevalidator) writeDealState(deal rm.ProviderDealState) {
channel := pr.trackedChannels[deal.ChannelID]
channel := pr.trackedChannels[*deal.ChannelID]
channel.totalSent = deal.TotalSent
if !deal.PricePerByte.IsZero() {
channel.totalPaidFor = big.Div(big.Max(big.Sub(deal.FundsReceived, deal.UnsealPrice), big.Zero()), deal.PricePerByte).Uint64()
Expand Down
35 changes: 18 additions & 17 deletions retrievalmarket/impl/requestvalidation/revalidator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"record block": {
deal: deal,
channelID: deal.ChannelID,
channelID: *deal.ChannelID,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventBlockSent,
expectedArgs: []interface{}{deal.TotalSent + uint64(500)},
Expand All @@ -64,7 +64,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"record block zero price per byte": {
deal: dealZeroPricePerByte,
channelID: dealZeroPricePerByte.ChannelID,
channelID: *dealZeroPricePerByte.ChannelID,
expectedID: dealZeroPricePerByte.Identifier(),
expectedEvent: rm.ProviderEventBlockSent,
expectedArgs: []interface{}{dealZeroPricePerByte.TotalSent + uint64(500)},
Expand All @@ -73,7 +73,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"request payment": {
deal: deal,
channelID: deal.ChannelID,
channelID: *deal.ChannelID,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentRequested,
expectedArgs: []interface{}{deal.TotalSent + defaultCurrentInterval},
Expand All @@ -88,7 +88,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"request payment, legacy": {
deal: legacyDeal,
channelID: legacyDeal.ChannelID,
channelID: *legacyDeal.ChannelID,
expectedID: legacyDeal.Identifier(),
expectedEvent: rm.ProviderEventPaymentRequested,
expectedArgs: []interface{}{legacyDeal.TotalSent + defaultCurrentInterval},
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestOnComplete(t *testing.T) {
dealZeroPricePerByte.PricePerByte = big.Zero()
legacyDeal := deal
legacyDeal.LegacyProtocol = true
channelID := deal.ChannelID
channelID := *deal.ChannelID
testCases := map[string]struct {
expectedEvents []eventSent
deal rm.ProviderDealState
Expand Down Expand Up @@ -296,6 +296,7 @@ func TestRevalidate(t *testing.T) {

deal := *makeDealState(rm.DealStatusFundsNeeded)
deal.TotalSent = defaultTotalSent + defaultCurrentInterval
channelID := *deal.ChannelID
smallerPayment := abi.NewTokenAmount(400000)
payment := &retrievalmarket.DealPayment{
ID: deal.ID,
Expand Down Expand Up @@ -329,7 +330,7 @@ func TestRevalidate(t *testing.T) {
},
"not a payment voucher": {
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
noSend: true,
expectedError: errors.New("wrong voucher type"),
},
Expand All @@ -338,7 +339,7 @@ func TestRevalidate(t *testing.T) {
tn.ChainHeadError = errors.New("something went wrong")
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: errors.New("something went wrong"),
expectedID: deal.Identifier(),
Expand All @@ -355,7 +356,7 @@ func TestRevalidate(t *testing.T) {
tn.ChainHeadError = errors.New("something went wrong")
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: errors.New("something went wrong"),
expectedID: deal.Identifier(),
Expand All @@ -372,7 +373,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, abi.NewTokenAmount(0), errors.New("your money's no good here"))
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: errors.New("your money's no good here"),
expectedID: deal.Identifier(),
Expand All @@ -389,7 +390,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, abi.NewTokenAmount(0), errors.New("your money's no good here"))
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: errors.New("your money's no good here"),
expectedID: deal.Identifier(),
Expand All @@ -406,7 +407,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, smallerPayment, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: datatransfer.ErrPause,
expectedID: deal.Identifier(),
Expand All @@ -423,7 +424,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, smallerPayment, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: datatransfer.ErrPause,
expectedID: deal.Identifier(),
Expand All @@ -440,7 +441,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -452,7 +453,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: lastPaymentDeal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -467,7 +468,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: lastPaymentDeal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -482,7 +483,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, big.Zero(), nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand Down Expand Up @@ -565,7 +566,7 @@ func makeDealState(status retrievalmarket.DealStatus) *retrievalmarket.ProviderD
TotalSent: defaultTotalSent,
CurrentInterval: defaultCurrentInterval,
FundsReceived: defaultFundsReceived,
ChannelID: channelID,
ChannelID: &channelID,
Receiver: channelID.Initiator,
DealProposal: retrievalmarket.DealProposal{
ID: dealID,
Expand Down
Loading