Skip to content

Commit

Permalink
Fix inconsistent string lookup functions (#5437)
Browse files Browse the repository at this point in the history
* fix inconsistent string lookup functions

* test client type and ordering

* channel and connection state tests

* address golangcibot comments

* fix test
  • Loading branch information
fedekunze authored Jan 2, 2020
1 parent 73a7e5d commit e06d6a9
Show file tree
Hide file tree
Showing 16 changed files with 276 additions and 100 deletions.
14 changes: 7 additions & 7 deletions x/ibc/02-client/exported/exported.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,23 @@ func (ct *ClientType) UnmarshalJSON(data []byte) error {
return err
}

bz2, err := ClientTypeFromString(s)
if err != nil {
return err
clientType := ClientTypeFromString(s)
if clientType == 0 {
return fmt.Errorf("invalid client type '%s'", s)
}

*ct = bz2
*ct = clientType
return nil
}

// ClientTypeFromString returns a byte that corresponds to the registered client
// type. It returns 0 if the type is not found/registered.
func ClientTypeFromString(clientType string) (ClientType, error) {
func ClientTypeFromString(clientType string) ClientType {
switch clientType {
case ClientTypeTendermint:
return Tendermint, nil
return Tendermint

default:
return 0, fmt.Errorf("'%s' is not a valid client type", clientType)
return 0
}
}
49 changes: 49 additions & 0 deletions x/ibc/02-client/exported/exported_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package exported

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestClientTypeString(t *testing.T) {
cases := []struct {
msg string
name string
clientType ClientType
}{
{"tendermint client", ClientTypeTendermint, Tendermint},
{"empty type", "", 0},
}

for _, tt := range cases {
tt := tt
require.Equal(t, tt.clientType, ClientTypeFromString(tt.name), tt.msg)
require.Equal(t, tt.name, tt.clientType.String(), tt.msg)
}
}

func TestClientTypeMarshalJSON(t *testing.T) {
cases := []struct {
msg string
name string
clientType ClientType
expectPass bool
}{
{"tendermint client should have passed", ClientTypeTendermint, Tendermint, true},
{"empty type should have failed", "", 0, false},
}

for _, tt := range cases {
tt := tt
bz, err := tt.clientType.MarshalJSON()
require.NoError(t, err)
var ct ClientType
if tt.expectPass {
require.NoError(t, ct.UnmarshalJSON(bz), tt.msg)
require.Equal(t, tt.name, ct.String(), tt.msg)
} else {
require.Error(t, ct.UnmarshalJSON(bz), tt.msg)
}
}
}
10 changes: 6 additions & 4 deletions x/ibc/02-client/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (

// HandleMsgCreateClient defines the sdk.Handler for MsgCreateClient
func HandleMsgCreateClient(ctx sdk.Context, k Keeper, msg MsgCreateClient) sdk.Result {
clientType, err := exported.ClientTypeFromString(msg.ClientType)
if err != nil {
return sdk.ResultFromError(ErrInvalidClientType(DefaultCodespace, err.Error()))
clientType := exported.ClientTypeFromString(msg.ClientType)
if clientType == 0 {
return sdk.ResultFromError(
ErrInvalidClientType(DefaultCodespace, fmt.Sprintf("invalid client type '%s'", msg.ClientType)),
)
}

_, err = k.CreateClient(ctx, msg.ClientID, clientType, msg.ConsensusState)
_, err := k.CreateClient(ctx, msg.ClientID, clientType, msg.ConsensusState)
if err != nil {
return sdk.ResultFromError(err)
}
Expand Down
8 changes: 6 additions & 2 deletions x/ibc/02-client/types/msgs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package types

import (
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
evidenceexported "github.com/cosmos/cosmos-sdk/x/evidence/exported"
"github.com/cosmos/cosmos-sdk/x/ibc/02-client/exported"
Expand Down Expand Up @@ -51,8 +53,10 @@ func (msg MsgCreateClient) ValidateBasic() sdk.Error {
if err := host.DefaultClientIdentifierValidator(msg.ClientID); err != nil {
return sdk.ConvertError(err)
}
if _, err := exported.ClientTypeFromString(msg.ClientType); err != nil {
return sdk.ConvertError(errors.ErrInvalidClientType(errors.DefaultCodespace, err.Error()))
if clientType := exported.ClientTypeFromString(msg.ClientType); clientType == 0 {
return sdk.ConvertError(
errors.ErrInvalidClientType(errors.DefaultCodespace, fmt.Sprintf("invalid client type '%s'", msg.ClientType)),
)
}
if msg.ConsensusState == nil {
return sdk.ConvertError(errors.ErrInvalidConsensus(errors.DefaultCodespace))
Expand Down
4 changes: 2 additions & 2 deletions x/ibc/03-connection/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
)

const (
NONE = types.NONE
UNINITIALIZED = types.UNINITIALIZED
INIT = types.INIT
TRYOPEN = types.TRYOPEN
OPEN = types.OPEN
StateNone = types.StateNone
StateUninitialized = types.StateUninitialized
StateInit = types.StateInit
StateTryOpen = types.StateTryOpen
StateOpen = types.StateOpen
Expand Down
2 changes: 1 addition & 1 deletion x/ibc/03-connection/keeper/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (k Keeper) ConnOpenTry(
version := types.PickVersion(counterpartyVersions, types.GetCompatibleVersions())

// connection defines chain B's ConnectionEnd
connection := types.NewConnectionEnd(types.NONE, clientID, counterparty, []string{version})
connection := types.NewConnectionEnd(types.UNINITIALIZED, clientID, counterparty, []string{version})
expConnBz, err := k.cdc.MarshalBinaryLengthPrefixed(expectedConn)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion x/ibc/03-connection/keeper/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (suite *KeeperTestSuite) TestConnOpenAck() {
}

invalidConnectionState := func() error {
suite.createConnection(testConnectionID2, testConnectionID1, testClientID2, testClientID1, connection.NONE)
suite.createConnection(testConnectionID2, testConnectionID1, testClientID2, testClientID1, connection.UNINITIALIZED)
//suite.updateClient(testClientID2)

proofTry, proofHeight := suite.queryProof(connectionKey)
Expand Down
2 changes: 1 addition & 1 deletion x/ibc/03-connection/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (suite KeeperTestSuite) TestGetAllConnections() {
}

conn3 := types.ConnectionEnd{
State: types.NONE,
State: types.UNINITIALIZED,
ClientID: testClientID3,
Counterparty: counterparty2,
Versions: types.GetCompatibleVersions(),
Expand Down
34 changes: 12 additions & 22 deletions x/ibc/03-connection/types/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding/json"
"fmt"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
commitment "github.com/cosmos/cosmos-sdk/x/ibc/23-commitment"
Expand Down Expand Up @@ -72,49 +71,45 @@ type State byte

// available connection states
const (
NONE State = iota // default State
UNINITIALIZED State = iota // default State
INIT
TRYOPEN
OPEN
)

// string representation of the connection states
const (
StateNone string = "NONE"
StateInit string = "INIT"
StateTryOpen string = "TRYOPEN"
StateOpen string = "OPEN"
StateUninitialized string = "UNINITIALIZED"
StateInit string = "INIT"
StateTryOpen string = "TRYOPEN"
StateOpen string = "OPEN"
)

// String implements the Stringer interface
func (cs State) String() string {
switch cs {
case NONE:
return StateNone
case INIT:
return StateInit
case TRYOPEN:
return StateTryOpen
case OPEN:
return StateOpen
default:
return ""
return StateUninitialized
}
}

// StateFromString parses a string into a connection state
func StateFromString(state string) (State, error) {
func StateFromString(state string) State {
switch state {
case StateNone:
return NONE, nil
case StateInit:
return INIT, nil
return INIT
case StateTryOpen:
return TRYOPEN, nil
return TRYOPEN
case StateOpen:
return OPEN, nil
return OPEN
default:
return NONE, fmt.Errorf("'%s' is not a valid connection state", state)
return UNINITIALIZED
}
}

Expand All @@ -131,11 +126,6 @@ func (cs *State) UnmarshalJSON(data []byte) error {
return err
}

bz2, err := StateFromString(s)
if err != nil {
return err
}

*cs = bz2
*cs = StateFromString(s)
return nil
}
46 changes: 46 additions & 0 deletions x/ibc/03-connection/types/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package types

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestConnectionStateString(t *testing.T) {
cases := []struct {
name string
state State
}{
{StateUninitialized, UNINITIALIZED},
{StateInit, INIT},
{StateTryOpen, TRYOPEN},
{StateOpen, OPEN},
}

for _, tt := range cases {
tt := tt
require.Equal(t, tt.state, StateFromString(tt.name))
require.Equal(t, tt.name, tt.state.String())
}
}

func TestConnectionlStateMarshalJSON(t *testing.T) {
cases := []struct {
name string
state State
}{
{StateUninitialized, UNINITIALIZED},
{StateInit, INIT},
{StateTryOpen, TRYOPEN},
{StateOpen, OPEN},
}

for _, tt := range cases {
tt := tt
bz, err := tt.state.MarshalJSON()
require.NoError(t, err)
var state State
require.NoError(t, state.UnmarshalJSON(bz))
require.Equal(t, tt.name, state.String())
}
}
9 changes: 5 additions & 4 deletions x/ibc/04-channel/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@ import (
)

const (
NONE = types.NONE
UNINITIALIZED = types.UNINITIALIZED
UNORDERED = types.UNORDERED
ORDERED = types.ORDERED
OrderNone = types.OrderNone
OrderUnordered = types.OrderUnordered
OrderOrdered = types.OrderOrdered
CLOSED = types.CLOSED
INIT = types.INIT
OPENTRY = types.OPENTRY
TRYOPEN = types.TRYOPEN
OPEN = types.OPEN
StateClosed = types.StateClosed
StateUninitialized = types.StateUninitialized
StateInit = types.StateInit
StateOpenTry = types.StateOpenTry
StateTryOpen = types.StateTryOpen
StateOpen = types.StateOpen
StateClosed = types.StateClosed
DefaultCodespace = types.DefaultCodespace
CodeChannelExists = types.CodeChannelExists
CodeChannelNotFound = types.CodeChannelNotFound
Expand Down
8 changes: 4 additions & 4 deletions x/ibc/04-channel/keeper/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (k Keeper) ChanOpenInit(
return connection.ErrConnectionNotFound(k.codespace, connectionHops[0])
}

if connectionEnd.State == connection.NONE {
if connectionEnd.State == connection.UNINITIALIZED {
return connection.ErrInvalidConnectionState(
k.codespace,
fmt.Sprintf("connection state cannot be NONE"),
Expand Down Expand Up @@ -112,7 +112,7 @@ func (k Keeper) ChanOpenTry(

// NOTE: this step has been switched with the one below to reverse the connection
// hops
channel := types.NewChannel(types.OPENTRY, order, counterparty, connectionHops, version)
channel := types.NewChannel(types.TRYOPEN, order, counterparty, connectionHops, version)

counterpartyHops, found := k.CounterpartyHops(ctx, channel)
if !found {
Expand Down Expand Up @@ -200,7 +200,7 @@ func (k Keeper) ChanOpenAck(
// counterparty of the counterparty channel end (i.e self)
counterparty := types.NewCounterparty(portID, channelID)
expectedChannel := types.NewChannel(
types.OPENTRY, channel.Ordering, counterparty,
types.TRYOPEN, channel.Ordering, counterparty,
counterpartyHops, channel.Version,
)

Expand Down Expand Up @@ -238,7 +238,7 @@ func (k Keeper) ChanOpenConfirm(
return types.ErrChannelNotFound(k.codespace, portID, channelID)
}

if channel.State != types.OPENTRY {
if channel.State != types.TRYOPEN {
return types.ErrInvalidChannelState(
k.codespace,
fmt.Sprintf("channel state is not OPENTRY (got %s)", channel.State.String()),
Expand Down
Loading

0 comments on commit e06d6a9

Please sign in to comment.