diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index a09c41cf8af..51c2747aadc 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -366,11 +366,103 @@ func (k Keeper) ChanUpgradeTimeout( ctx sdk.Context, portID, channelID string, counterpartyChannel types.Channel, - prevErrorReceipt types.ErrorReceipt, + prevErrorReceipt *types.ErrorReceipt, proofCounterpartyChannel, proofErrorReceipt []byte, proofHeight exported.Height, ) error { + channel, found := k.GetChannel(ctx, portID, channelID) + if !found { + return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + + if channel.State != types.INITUPGRADE { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "channel state is not INITUPGRADE (got %s)", channel.State) + } + + upgrade, found := k.GetUpgrade(ctx, portID, channelID) + if !found { + return errorsmod.Wrapf(types.ErrUpgradeNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + + connection, found := k.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0]) + if !found { + return errorsmod.Wrap( + connectiontypes.ErrConnectionNotFound, + channel.ConnectionHops[0], + ) + } + + if connection.GetState() != int32(connectiontypes.OPEN) { + return errorsmod.Wrapf( + connectiontypes.ErrInvalidConnectionState, + "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), + ) + } + + // proof must be from a height after timeout has elapsed. Either timeoutHeight or timeoutTimestamp must be defined. + // if timeoutHeight is defined and proof is from before timeout height, abort transaction + proofTimestamp, err := k.connectionKeeper.GetTimestampAtHeight(ctx, connection, proofHeight) + if err != nil { + return err + } + + timeout := upgrade.Timeout + proofHeightIsInvalid := timeout.Height.IsZero() || proofHeight.LT(timeout.Height) + proofTimestampIsInvalid := timeout.Timestamp == 0 || proofTimestamp < timeout.Timestamp + if proofHeightIsInvalid && proofTimestampIsInvalid { + return errorsmod.Wrap(types.ErrInvalidUpgradeTimeout, "timeout has not yet passed on counterparty chain") + } + + // counterparty channel must be proved to still be in OPEN state or INITUPGRADE state (crossing hellos) + if !collections.Contains(counterpartyChannel.State, []types.State{types.OPEN, types.INITUPGRADE}) { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.OPEN, types.INITUPGRADE, counterpartyChannel.State) + } + + // verify the counterparty channel state + if err := k.connectionKeeper.VerifyChannelState( + ctx, + connection, + proofHeight, proofCounterpartyChannel, + channel.Counterparty.PortId, + channel.Counterparty.ChannelId, + counterpartyChannel, + ); err != nil { + return errorsmod.Wrap(err, "failed to verify counterparty channel state") + } + + // Error receipt passed in is either nil or it is a stale error receipt from a previous upgrade + if prevErrorReceipt == nil { + if err := k.connectionKeeper.VerifyChannelUpgradeErrorAbsence( + ctx, + channel.Counterparty.PortId, channel.Counterparty.ChannelId, + connection, + proofErrorReceipt, + proofHeight, + ); err != nil { + return errorsmod.Wrap(err, "failed to verify absence of counterparty channel upgrade error receipt") + } + + return nil + } + // timeout for this sequence can only succeed if the error receipt written into the error path on the counterparty + // was for a previous sequence by the timeout deadline. + upgradeSequence := channel.UpgradeSequence + if upgradeSequence < prevErrorReceipt.Sequence { + return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "previous counterparty error receipt sequence is greater than our current upgrade sequence: %d > %d", prevErrorReceipt.Sequence, upgradeSequence) + } + + if err := k.connectionKeeper.VerifyChannelUpgradeError( + ctx, + channel.Counterparty.PortId, channel.Counterparty.ChannelId, + connection, + *prevErrorReceipt, + proofErrorReceipt, + proofHeight, + ); err != nil { + return errorsmod.Wrap(err, "failed to verify counterparty channel upgrade error receipt") + } + return nil } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 44d32871742..628329e4f85 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -485,6 +485,197 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { } } +func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { + var ( + path *ibctesting.Path + errReceipt *types.ErrorReceipt + proofHeight exported.Height + proofCounterpartyChannel []byte + proofErrorReceipt []byte + ) + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success: proof height has passed", + func() {}, + nil, + }, + { + "success: proof timestamp has passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Height = defaultTimeoutHeight + upgrade.Timeout.Timestamp = 5 + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + nil, + }, + { + "success: non-nil error receipt", + func() { + errReceipt = &types.ErrorReceipt{ + Sequence: 1, + Message: types.ErrInvalidUpgrade.Error(), + } + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, *errReceipt) + + suite.Require().NoError(path.EndpointB.UpdateClient()) + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + nil, + }, + { + "channel not found", + func() { + path.EndpointA.ChannelID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not in INITUPGRADE state", + func() { + suite.Require().NoError(path.EndpointA.SetChannelState(types.ACKUPGRADE)) + }, + types.ErrInvalidChannelState, + }, + { + "current upgrade not found", + func() { + suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + types.ErrUpgradeNotFound, + }, + { + "connection not found", + func() { + channel := path.EndpointA.GetChannel() + channel.ConnectionHops[0] = ibctesting.InvalidID + path.EndpointA.SetChannel(channel) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "connection not open", + func() { + connectionEnd := path.EndpointA.GetConnection() + connectionEnd.State = connectiontypes.UNINITIALIZED + path.EndpointA.SetConnection(connectionEnd) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "unable to retrieve timestamp at proof height", + func() { + proofHeight = suite.chainA.GetTimeoutHeight() + }, + clienttypes.ErrConsensusStateNotFound, + }, + { + "timeout has not passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Height = suite.chainA.GetTimeoutHeight() + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + types.ErrInvalidUpgradeTimeout, + }, + { + "counterparty channel state is not OPEN or INITUPGRADE (crossing hellos)", + func() { + channel := path.EndpointB.GetChannel() + channel.State = types.TRYUPGRADE + path.EndpointB.SetChannel(channel) + + suite.Require().NoError(path.EndpointB.UpdateClient()) + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + types.ErrInvalidChannelState, + }, + { + "non-nil error receipt: error receipt seq greater than current upgrade seq", + func() { + errReceipt = &types.ErrorReceipt{ + Sequence: 3, + Message: types.ErrInvalidUpgrade.Error(), + } + }, + types.ErrInvalidUpgradeSequence, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + expPass := tc.expError == nil + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + errReceipt = nil + + // set timeout height to 1 to ensure timeout + path.EndpointA.ChannelConfig.ProposedUpgrade.Timeout.Height = clienttypes.NewHeight(1, 1) + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // ensure clients are up to date to receive valid proofs + suite.Require().NoError(path.EndpointB.UpdateClient()) + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTimeout( + suite.chainA.GetContext(), + path.EndpointA.ChannelConfig.PortID, + path.EndpointA.ChannelID, + path.EndpointB.GetChannel(), + errReceipt, + proofCounterpartyChannel, + proofErrorReceipt, + proofHeight, + ) + + if expPass { + suite.Require().NoError(err) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} + // TestStartFlushUpgradeHandshake tests the startFlushUpgradeHandshake. // UpgradeInit will be run on chainA and startFlushUpgradeHandshake // will be called on chainB diff --git a/modules/core/04-channel/types/errors.go b/modules/core/04-channel/types/errors.go index cb668b3c824..50f3e9f8a0b 100644 --- a/modules/core/04-channel/types/errors.go +++ b/modules/core/04-channel/types/errors.go @@ -49,4 +49,5 @@ var ( ErrInvalidFlushStatus = errorsmod.Register(SubModuleName, 33, "invalid flush status") ErrUpgradeRestoreFailed = errorsmod.Register(SubModuleName, 34, "restore failed") ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out") + ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid") )