Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor chanUpgradeAck tests to use expected errors #3843

Merged
merged 8 commits into from
Jun 20, 2023
138 changes: 69 additions & 69 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,75 +205,6 @@ func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string
return channel, upgrade
}

// WriteUpgradeAckChannel writes a channel which has successfully passed the UpgradeAck handshake step as well as
// setting the upgrade for that channel.
// An event is emitted for the handshake step.
func (k Keeper) WriteUpgradeAckChannel(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved this to below ChanUpgradeAck

ctx sdk.Context,
portID, channelID string,
proposedUpgrade types.Upgrade,
) {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-ack")

channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find existing channel when updating channel state in successful ChanUpgradeAck step, channelID: %s, portID: %s", channelID, portID))
}

previousState := channel.State
channel.State = types.ACKUPGRADE
channel.FlushStatus = types.FLUSHING

k.SetChannel(ctx, portID, channelID, channel)
k.SetUpgrade(ctx, portID, channelID, proposedUpgrade)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.ACKUPGRADE.String())
emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, proposedUpgrade)
}

// ChanUpgradeCancel is called by a module to cancel a channel upgrade that is in progress.
func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt, errorReceiptProof []byte, proofHeight clienttypes.Height) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

// the channel state must be in INITUPGRADE or TRYUPGRADE
if !collections.Contains(channel.State, []types.State{types.INITUPGRADE, types.TRYUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.INITUPGRADE, types.TRYUPGRADE, channel.State)
}

// get underlying connection for proof verification
connection, err := k.GetConnection(ctx, channel.ConnectionHops[0])
if err != nil {
return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops")
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
}

if err := k.connectionKeeper.VerifyChannelUpgradeError(ctx, portID, channelID, connection, errorReceipt, errorReceiptProof, proofHeight); err != nil {
return errorsmod.Wrap(err, "failed to verify counterparty error receipt")
}

// If counterparty sequence is less than the current sequence, abort the transaction since this error receipt is from a previous upgrade.
// Otherwise, set our upgrade sequence to the counterparty's error sequence + 1 so that both sides start with a fresh sequence.
currentSequence := channel.UpgradeSequence
counterpartySequence := errorReceipt.Sequence
if counterpartySequence < currentSequence {
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than or equal to current sequence (%d)", counterpartySequence, currentSequence)
}

channel.UpgradeSequence = errorReceipt.Sequence + 1
k.SetChannel(ctx, portID, channelID, channel)

return nil
}

// WriteUpgradeCancelChannel writes a channel which has canceled the upgrade process.Auxiliary upgrade state is
// also deleted.
func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string) {
Expand Down Expand Up @@ -360,6 +291,75 @@ func (k Keeper) ChanUpgradeAck(
return nil
}

// WriteUpgradeAckChannel writes a channel which has successfully passed the UpgradeAck handshake step as well as
// setting the upgrade for that channel.
// An event is emitted for the handshake step.
func (k Keeper) WriteUpgradeAckChannel(
ctx sdk.Context,
portID, channelID string,
proposedUpgrade types.Upgrade,
) {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-ack")

channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
panic(fmt.Sprintf("could not find existing channel when updating channel state in successful ChanUpgradeAck step, channelID: %s, portID: %s", channelID, portID))
}

previousState := channel.State
channel.State = types.ACKUPGRADE
channel.FlushStatus = types.FLUSHING

k.SetChannel(ctx, portID, channelID, channel)
k.SetUpgrade(ctx, portID, channelID, proposedUpgrade)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.ACKUPGRADE.String())
emitChannelUpgradeAckEvent(ctx, portID, channelID, channel, proposedUpgrade)
}

// ChanUpgradeCancel is called by a module to cancel a channel upgrade that is in progress.
func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt, errorReceiptProof []byte, proofHeight clienttypes.Height) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

// the channel state must be in INITUPGRADE or TRYUPGRADE
if !collections.Contains(channel.State, []types.State{types.INITUPGRADE, types.TRYUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.INITUPGRADE, types.TRYUPGRADE, channel.State)
}

// get underlying connection for proof verification
connection, err := k.GetConnection(ctx, channel.ConnectionHops[0])
if err != nil {
return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops")
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
}

if err := k.connectionKeeper.VerifyChannelUpgradeError(ctx, portID, channelID, connection, errorReceipt, errorReceiptProof, proofHeight); err != nil {
return errorsmod.Wrap(err, "failed to verify counterparty error receipt")
}

// If counterparty sequence is less than the current sequence, abort the transaction since this error receipt is from a previous upgrade.
// Otherwise, set our upgrade sequence to the counterparty's error sequence + 1 so that both sides start with a fresh sequence.
currentSequence := channel.UpgradeSequence
counterpartySequence := errorReceipt.Sequence
if counterpartySequence < currentSequence {
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than or equal to current sequence (%d)", counterpartySequence, currentSequence)
}

channel.UpgradeSequence = errorReceipt.Sequence + 1
k.SetChannel(ctx, portID, channelID, channel)

return nil
}

// ChanUpgradeTimeout times out an outstanding upgrade.
// This should be used by the initialising chain when the counterparty chain has not responded to an upgrade proposal within the specified timeout period.
func (k Keeper) ChanUpgradeTimeout(
Expand Down
58 changes: 30 additions & 28 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
testCases := []struct {
name string
malleate func()
expPass bool
expError error
}{
{
"success",
func() {},
true,
nil,
},
{
"success with later upgrade sequence",
Expand All @@ -359,29 +359,29 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
err := path.EndpointA.UpdateClient()
suite.Require().NoError(err)
},
true,
nil,
},
{
"channel not found",
func() {
path.EndpointA.ChannelID = ibctesting.InvalidID
path.EndpointA.ChannelConfig.PortID = ibctesting.InvalidID
},
false,
types.ErrChannelNotFound,
},
{
"channel state is not in INITUPGRADE or TRYUPGRADE state",
func() {
suite.Require().NoError(path.EndpointA.SetChannelState(types.CLOSED))
},
false,
types.ErrInvalidChannelState,
},
{
"counterparty flush status is not in FLUSHING or FLUSHCOMPLETE",
func() {
counterpartyFlushStatus = types.NOTINFLUSH
},
false,
types.ErrInvalidFlushStatus,
},
{
"connection not found",
Expand All @@ -390,7 +390,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
channel.ConnectionHops = []string{"connection-100"}
path.EndpointA.SetChannel(channel)
},
false,
connectiontypes.ErrConnectionNotFound,
},
{
"invalid connection state",
Expand All @@ -399,46 +399,47 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
connectionEnd.State = connectiontypes.UNINITIALIZED
path.EndpointA.SetConnection(connectionEnd)
},
false,
connectiontypes.ErrInvalidConnectionState,
},
{
"upgrade not found",
func() {
store := suite.chainA.GetContext().KVStore(suite.chainA.GetSimApp().GetKey(exported.ModuleName))
store.Delete(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
},
false,
types.ErrUpgradeNotFound,
},
{
"channel end version mismatch on crossing hellos",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to bottom as this is last in the call stack

"startFlushUpgradeHandshake fails due to proof verification failure, counterparty upgrade connection hops are tampered with",
func() {
channel := path.EndpointA.GetChannel()
channel.State = types.TRYUPGRADE

path.EndpointA.SetChannel(channel)

upgrade := path.EndpointA.GetChannelUpgrade()
upgrade.Fields.Version = "invalid-version"

path.EndpointA.SetChannelUpgrade(upgrade)
counterpartyUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID}
},
false,
commitmenttypes.ErrInvalidProof,
},
{
"startFlushUpgradeHandshake fails due to proof verification failure, counterparty upgrade connection hops are tampered with",
"startFlushUpgradeHandshake fails due to mismatch in upgrade ordering",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this testcase from: "startFlushUpgradeHandshake fails due to mismatch in upgrade sequences" to this one which fails with ordering mismatch.

This was the one testcase which was actually failing on proof verification before we got to the checks afterwards.
We can't force mismatch on upgrade sequences in this test as the counterparty channel upgrade sequence is built using our channel upgrade sequence, so mutating one or the other causes proof verification to fail. Decided to test for mismatch in upgrade ordering instead

func() {
counterpartyUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID}
upgrade := path.EndpointA.GetChannelUpgrade()
upgrade.Fields.Ordering = types.NONE

path.EndpointA.SetChannelUpgrade(upgrade)
},
false,
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
},
{
"startFlushUpgradeHandshake fails due to mismatch in upgrade sequences",
"channel end version mismatch on crossing hellos",
func() {
channel := path.EndpointA.GetChannel()
channel.UpgradeSequence = 5
channel.State = types.TRYUPGRADE

path.EndpointA.SetChannel(channel)

upgrade := path.EndpointA.GetChannelUpgrade()
upgrade.Fields.Version = "invalid-version"

path.EndpointA.SetChannelUpgrade(upgrade)
},
false,
types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade),
},
}

Expand Down Expand Up @@ -476,10 +477,11 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
proofChannel, proofUpgrade, proofHeight,
)

if tc.expPass {
expPass := tc.expError == nil
if expPass {
suite.Require().NoError(err)
} else {
suite.Require().Error(err)
suite.assertUpgradeError(err, tc.expError)
}
})
}
Expand Down
Loading