diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index fdc1d872c69..a92ef651627 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -230,6 +230,49 @@ func (k Keeper) WriteUpgradeAckChannel( emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, proposedUpgrade) } +// ChanUpgradeCancel is called by a module to cancel a channel upgrade that is in progress. +func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt, errorReceiptProof []byte, proofHeight clienttypes.Height) error { + channel, found := k.GetChannel(ctx, portID, channelID) + if !found { + return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + + // the channel state must be in INITUPGRADE or TRYUPGRADE + if !collections.Contains(channel.State, []types.State{types.INITUPGRADE, types.TRYUPGRADE}) { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.INITUPGRADE, types.TRYUPGRADE, channel.State) + } + + // get underlying connection for proof verification + connection, err := k.GetConnection(ctx, channel.ConnectionHops[0]) + if err != nil { + return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops") + } + + if connection.GetState() != int32(connectiontypes.OPEN) { + return errorsmod.Wrapf( + connectiontypes.ErrInvalidConnectionState, + "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), + ) + } + + if err := k.connectionKeeper.VerifyChannelUpgradeError(ctx, portID, channelID, connection, errorReceipt, errorReceiptProof, proofHeight); err != nil { + return errorsmod.Wrap(err, "failed to verify counterparty error receipt") + } + + // If counterparty sequence is less than the current sequence, abort the transaction since this error receipt is from a previous upgrade. + // Otherwise, set our upgrade sequence to the counterparty's error sequence + 1 so that both sides start with a fresh sequence. + currentSequence := channel.UpgradeSequence + counterpartySequence := errorReceipt.Sequence + if counterpartySequence < currentSequence { + return errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error sequence must be less than current sequence") + } + + channel.UpgradeSequence = errorReceipt.Sequence + 1 + k.SetChannel(ctx, portID, channelID, channel) + + return nil +} + // WriteUpgradeCancelChannel writes a channel which has canceled the upgrade process.Auxiliary upgrade state is // also deleted. func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string) { @@ -245,8 +288,11 @@ func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID str panic(fmt.Sprintf("could not find existing channel when updating channel state, channelID: %s, portID: %s", channelID, portID)) } + previousState := channel.State + k.restoreChannel(ctx, portID, channelID, channel) + k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.OPEN.String()) emitChannelUpgradeCancelEvent(ctx, portID, channelID, channel, upgrade) } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 269b4f32c9c..44d32871742 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -4,6 +4,7 @@ import ( "fmt" errorsmod "cosmossdk.io/errors" + sdk "github.com/cosmos/cosmos-sdk/types" clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" connectiontypes "github.com/cosmos/ibc-go/v7/modules/core/03-connection/types" @@ -876,3 +877,116 @@ func (suite *KeeperTestSuite) TestAbortHandshake() { }) } } + +func (suite *KeeperTestSuite) TestChanUpgradeCancel() { + var ( + path *ibctesting.Path + errorReceipt types.ErrorReceipt + errorReceiptProof []byte + proofHeight clienttypes.Height + ) + tests := []struct { + name string + malleate func() + expError error + }{ + { + name: "success", + malleate: func() {}, + expError: nil, + }, + { + name: "invalid channel state", + malleate: func() { + channel := path.EndpointA.GetChannel() + channel.State = types.INIT + path.EndpointA.SetChannel(channel) + }, + expError: types.ErrInvalidChannelState, + }, + { + name: "channel not found", + malleate: func() { + path.EndpointA.Chain.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expError: types.ErrChannelNotFound, + }, + { + name: "connection not found", + malleate: func() { + channel := path.EndpointA.GetChannel() + channel.ConnectionHops = []string{"connection-100"} + path.EndpointA.SetChannel(channel) + }, + expError: connectiontypes.ErrConnectionNotFound, + }, + { + name: "counter partyupgrade sequence less than current sequence", + malleate: func() { + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + // the channel sequence will be 1 + errorReceipt.Sequence = 0 + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + suite.coordinator.CommitBlock(suite.chainB) + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + expError: types.ErrInvalidUpgradeSequence, + }, + } + + for _, tc := range tests { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + suite.Require().NoError(path.EndpointB.UpdateClient()) + + // cause the upgrade to fail on chain b so an error receipt is written. + suite.chainB.GetSimApp().IBCMockModule.IBCApp.OnChanUpgradeTry = func( + ctx sdk.Context, portID, channelID string, order types.Order, connectionHops []string, counterpartyVersion string, + ) (string, error) { + return "", fmt.Errorf("mock app callback failed") + } + + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeCancel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt, errorReceiptProof, proofHeight) + + expPass := tc.expError == nil + if expPass { + suite.Require().NoError(err) + channel := path.EndpointA.GetChannel() + suite.Require().Equal(errorReceipt.Sequence+1, channel.UpgradeSequence, "upgrade sequence should be incremented") + } else { + suite.Require().ErrorIs(err, tc.expError) + } + }) + } +}