diff --git a/app/app.go b/app/app.go index 5d9f63b53f..3f4a3e2ad1 100644 --- a/app/app.go +++ b/app/app.go @@ -371,8 +371,6 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b wasmDir, wasmConfig, supportedFeatures, - nil, - nil, wasmOpts..., ) diff --git a/x/wasm/alias.go b/x/wasm/alias.go index 6ae30298f6..608c1ec3fd 100644 --- a/x/wasm/alias.go +++ b/x/wasm/alias.go @@ -127,7 +127,7 @@ type ( Config = types.WasmConfig ContractInfoWithAddress = types.ContractInfoWithAddress CodeInfoResponse = types.CodeInfoResponse - MessageHandler = keeper.DefaultMessageHandler + MessageHandler = keeper.SDKMessageHandler BankEncoder = keeper.BankEncoder CustomEncoder = keeper.CustomEncoder StakingEncoder = keeper.StakingEncoder diff --git a/x/wasm/internal/keeper/bench_test.go b/x/wasm/internal/keeper/bench_test.go index 42fe9cf6a5..44e17930f1 100644 --- a/x/wasm/internal/keeper/bench_test.go +++ b/x/wasm/internal/keeper/bench_test.go @@ -40,7 +40,7 @@ func BenchmarkExecution(b *testing.B) { for name, spec := range specs { b.Run(name, func(b *testing.B) { wasmConfig := types.WasmConfig{MemoryCacheSize: 0} - ctx, keepers := createTestInput(b, false, SupportedFeatures, nil, nil, wasmConfig, spec.db()) + ctx, keepers := createTestInput(b, false, SupportedFeatures, wasmConfig, spec.db()) example := InstantiateHackatomExampleContract(b, ctx, keepers) if spec.pinned { require.NoError(b, keepers.WasmKeeper.PinCode(ctx, example.CodeID)) diff --git a/x/wasm/internal/keeper/genesis_test.go b/x/wasm/internal/keeper/genesis_test.go index c3de008a16..f7ff5e2044 100644 --- a/x/wasm/internal/keeper/genesis_test.go +++ b/x/wasm/internal/keeper/genesis_test.go @@ -643,7 +643,7 @@ func setupKeeper(t *testing.T) (*Keeper, sdk.Context, []sdk.StoreKey) { wasmConfig := wasmTypes.DefaultWasmConfig() pk := paramskeeper.NewKeeper(encodingConfig.Marshaler, encodingConfig.Amino, keyParams, tkeyParams) - srcKeeper := NewKeeper(encodingConfig.Marshaler, keyWasm, pk.Subspace(wasmTypes.DefaultParamspace), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, nil, nil, nil, nil, tempDir, wasmConfig, SupportedFeatures, nil, nil) + srcKeeper := NewKeeper(encodingConfig.Marshaler, keyWasm, pk.Subspace(wasmTypes.DefaultParamspace), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, nil, nil, nil, nil, tempDir, wasmConfig, SupportedFeatures) return &srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams} } diff --git a/x/wasm/internal/keeper/handler_plugin.go b/x/wasm/internal/keeper/handler_plugin.go index e263e2fa4b..61c9436869 100644 --- a/x/wasm/internal/keeper/handler_plugin.go +++ b/x/wasm/internal/keeper/handler_plugin.go @@ -1,326 +1,48 @@ package keeper import ( - "encoding/json" + "errors" "fmt" "github.com/CosmWasm/wasmd/x/wasm/internal/types" wasmvmtypes "github.com/CosmWasm/wasmvm/types" codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" - distributiontypes "github.com/cosmos/cosmos-sdk/x/distribution/types" - ibctransfertypes "github.com/cosmos/cosmos-sdk/x/ibc/applications/transfer/types" - ibcclienttypes "github.com/cosmos/cosmos-sdk/x/ibc/core/02-client/types" channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" host "github.com/cosmos/cosmos-sdk/x/ibc/core/24-host" - stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" ) -type DefaultMessageHandler struct { - router sdk.Router - encoders MessageEncoders +// msgEncoder is an extension point to customize encodings +type msgEncoder interface { + // Encode converts wasmvm message to n cosmos message types + Encode(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) ([]sdk.Msg, error) } -func NewDefaultMessageHandler(router sdk.Router, channelKeeper types.ChannelKeeper, capabilityKeeper types.CapabilityKeeper, unpacker codectypes.AnyUnpacker, customEncoders *MessageEncoders) DefaultMessageHandler { - encoders := DefaultEncoders(channelKeeper, capabilityKeeper, unpacker).Merge(customEncoders) - return DefaultMessageHandler{ - router: router, - encoders: encoders, - } -} - -type BankEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.BankMsg) ([]sdk.Msg, error) -type CustomEncoder func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) -type StakingEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.StakingMsg) ([]sdk.Msg, error) -type StargateEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.StargateMsg) ([]sdk.Msg, error) -type WasmEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg, error) -type IBCEncoder func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) - -type MessageEncoders struct { - Bank BankEncoder - Custom CustomEncoder - IBC IBCEncoder - Staking StakingEncoder - Stargate StargateEncoder - Wasm WasmEncoder -} - -func DefaultEncoders(channelKeeper types.ChannelKeeper, capabilityKeeper types.CapabilityKeeper, unpacker codectypes.AnyUnpacker) MessageEncoders { - return MessageEncoders{ - Bank: EncodeBankMsg, - Custom: NoCustomMsg, - IBC: EncodeIBCMsg(channelKeeper, capabilityKeeper), - Staking: EncodeStakingMsg, - Stargate: EncodeStargateMsg(unpacker), - Wasm: EncodeWasmMsg, - } -} - -func (e MessageEncoders) Merge(o *MessageEncoders) MessageEncoders { - if o == nil { - return e - } - if o.Bank != nil { - e.Bank = o.Bank - } - if o.Custom != nil { - e.Custom = o.Custom - } - if o.IBC != nil { - e.IBC = o.IBC - } - if o.Staking != nil { - e.Staking = o.Staking - } - if o.Stargate != nil { - e.Stargate = o.Stargate - } - if o.Wasm != nil { - e.Wasm = o.Wasm - } - return e -} - -func (e MessageEncoders) Encode(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) ([]sdk.Msg, error) { - switch { - case msg.Bank != nil: - return e.Bank(contractAddr, msg.Bank) - case msg.Custom != nil: - return e.Custom(contractAddr, msg.Custom) - case msg.IBC != nil: - return e.IBC(ctx, contractAddr, contractIBCPortID, msg.IBC) - case msg.Staking != nil: - return e.Staking(contractAddr, msg.Staking) - case msg.Stargate != nil: - return e.Stargate(contractAddr, msg.Stargate) - case msg.Wasm != nil: - return e.Wasm(contractAddr, msg.Wasm) - } - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Unknown variant of Wasm") -} - -func EncodeBankMsg(sender sdk.AccAddress, msg *wasmvmtypes.BankMsg) ([]sdk.Msg, error) { - if msg.Send == nil { - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Unknown variant of Bank") - } - if len(msg.Send.Amount) == 0 { - return nil, nil - } - toSend, err := convertWasmCoinsToSdkCoins(msg.Send.Amount) - if err != nil { - return nil, err - } - sdkMsg := banktypes.MsgSend{ - FromAddress: sender.String(), - ToAddress: msg.Send.ToAddress, - Amount: toSend, - } - return []sdk.Msg{&sdkMsg}, nil -} - -func NoCustomMsg(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Custom variant not supported") -} - -func EncodeStakingMsg(sender sdk.AccAddress, msg *wasmvmtypes.StakingMsg) ([]sdk.Msg, error) { - switch { - case msg.Delegate != nil: - coin, err := convertWasmCoinToSdkCoin(msg.Delegate.Amount) - if err != nil { - return nil, err - } - sdkMsg := stakingtypes.MsgDelegate{ - DelegatorAddress: sender.String(), - ValidatorAddress: msg.Delegate.Validator, - Amount: coin, - } - return []sdk.Msg{&sdkMsg}, nil - - case msg.Redelegate != nil: - coin, err := convertWasmCoinToSdkCoin(msg.Redelegate.Amount) - if err != nil { - return nil, err - } - sdkMsg := stakingtypes.MsgBeginRedelegate{ - DelegatorAddress: sender.String(), - ValidatorSrcAddress: msg.Redelegate.SrcValidator, - ValidatorDstAddress: msg.Redelegate.DstValidator, - Amount: coin, - } - return []sdk.Msg{&sdkMsg}, nil - case msg.Undelegate != nil: - coin, err := convertWasmCoinToSdkCoin(msg.Undelegate.Amount) - if err != nil { - return nil, err - } - sdkMsg := stakingtypes.MsgUndelegate{ - DelegatorAddress: sender.String(), - ValidatorAddress: msg.Undelegate.Validator, - Amount: coin, - } - return []sdk.Msg{&sdkMsg}, nil - case msg.Withdraw != nil: - senderAddr := sender.String() - rcpt := senderAddr - if len(msg.Withdraw.Recipient) != 0 { - rcpt = msg.Withdraw.Recipient - } - setMsg := distributiontypes.MsgSetWithdrawAddress{ - DelegatorAddress: senderAddr, - WithdrawAddress: rcpt, - } - withdrawMsg := distributiontypes.MsgWithdrawDelegatorReward{ - DelegatorAddress: senderAddr, - ValidatorAddress: msg.Withdraw.Validator, - } - return []sdk.Msg{&setMsg, &withdrawMsg}, nil - default: - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Unknown variant of Staking") - } -} - -func EncodeStargateMsg(unpacker codectypes.AnyUnpacker) StargateEncoder { - return func(sender sdk.AccAddress, msg *wasmvmtypes.StargateMsg) ([]sdk.Msg, error) { - any := codectypes.Any{ - TypeUrl: msg.TypeURL, - Value: msg.Value, - } - var sdkMsg sdk.Msg - if err := unpacker.UnpackAny(&any, &sdkMsg); err != nil { - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, fmt.Sprintf("Cannot unpack proto message with type URL: %s", msg.TypeURL)) - } - if err := codectypes.UnpackInterfaces(sdkMsg, unpacker); err != nil { - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, fmt.Sprintf("UnpackInterfaces inside msg: %s", err)) - } - return []sdk.Msg{sdkMsg}, nil - } -} - -func EncodeWasmMsg(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg, error) { - switch { - case msg.Execute != nil: - coins, err := convertWasmCoinsToSdkCoins(msg.Execute.Send) - if err != nil { - return nil, err - } - - sdkMsg := types.MsgExecuteContract{ - Sender: sender.String(), - Contract: msg.Execute.ContractAddr, - Msg: msg.Execute.Msg, - Funds: coins, - } - return []sdk.Msg{&sdkMsg}, nil - case msg.Instantiate != nil: - coins, err := convertWasmCoinsToSdkCoins(msg.Instantiate.Send) - if err != nil { - return nil, err - } - - sdkMsg := types.MsgInstantiateContract{ - Sender: sender.String(), - CodeID: msg.Instantiate.CodeID, - Label: msg.Instantiate.Label, - InitMsg: msg.Instantiate.Msg, - Funds: coins, - } - return []sdk.Msg{&sdkMsg}, nil - case msg.Migrate != nil: - sdkMsg := types.MsgMigrateContract{ - Sender: sender.String(), - Contract: msg.Migrate.ContractAddr, - CodeID: msg.Migrate.NewCodeID, - MigrateMsg: msg.Migrate.Msg, - } - return []sdk.Msg{&sdkMsg}, nil - default: - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Unknown variant of Wasm") - } -} - -func EncodeIBCMsg(channelKeeper types.ChannelKeeper, capabilityKeeper types.CapabilityKeeper) IBCEncoder { - return func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) { - switch { - case msg.SendPacket != nil: - if contractIBCPortID == "" { - return nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported") - } - contractIBCChannelID := msg.SendPacket.ChannelID - if contractIBCChannelID == "" { - return nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel") - } - - sequence, found := channelKeeper.GetNextSequenceSend(ctx, contractIBCPortID, contractIBCChannelID) - if !found { - return nil, sdkerrors.Wrapf( - channeltypes.ErrSequenceSendNotFound, - "source port: %s, source channel: %s", contractIBCPortID, contractIBCChannelID, - ) - } - - channelInfo, ok := channelKeeper.GetChannel(ctx, contractIBCPortID, contractIBCChannelID) - if !ok { - return nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found") - } - channelCap, ok := capabilityKeeper.GetCapability(ctx, host.ChannelCapabilityPath(contractIBCPortID, contractIBCChannelID)) - if !ok { - return nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability") - } - packet := channeltypes.NewPacket( - msg.SendPacket.Data, - sequence, - contractIBCPortID, - contractIBCChannelID, - channelInfo.Counterparty.PortId, - channelInfo.Counterparty.ChannelId, - convertWasmIBCTimeoutHeightToCosmosHeight(msg.SendPacket.TimeoutBlock), - convertWasmIBCTimeoutTimestampToCosmosTimestamp(msg.SendPacket.TimeoutTimestamp), - ) - return nil, channelKeeper.SendPacket(ctx, channelCap, packet) - case msg.CloseChannel != nil: - return []sdk.Msg{&channeltypes.MsgChannelCloseInit{ - PortId: PortIDForContract(sender), - ChannelId: msg.CloseChannel.ChannelID, - Signer: sender.String(), - }}, nil - case msg.Transfer != nil: - amount, err := convertWasmCoinToSdkCoin(msg.Transfer.Amount) - if err != nil { - return nil, sdkerrors.Wrap(err, "amount") - } - portID := ibctransfertypes.ModuleName //todo: port can be customized in genesis. make this more flexible - msg := &ibctransfertypes.MsgTransfer{ - SourcePort: portID, - SourceChannel: msg.Transfer.ChannelID, - Token: amount, - Sender: sender.String(), - Receiver: msg.Transfer.ToAddress, - TimeoutHeight: convertWasmIBCTimeoutHeightToCosmosHeight(msg.Transfer.TimeoutBlock), - TimeoutTimestamp: convertWasmIBCTimeoutTimestampToCosmosTimestamp(msg.Transfer.TimeoutTimestamp), - } - return []sdk.Msg{msg}, nil - default: - return nil, sdkerrors.Wrap(types.ErrInvalidMsg, "Unknown variant of IBC") - } - } +// SDKMessageHandler can handles messages that can be encoded into sdk.Message types and routed. +type SDKMessageHandler struct { + router sdk.Router + encoders msgEncoder } -func convertWasmIBCTimeoutHeightToCosmosHeight(ibcTimeoutBlock *wasmvmtypes.IBCTimeoutBlock) ibcclienttypes.Height { - if ibcTimeoutBlock == nil { - return ibcclienttypes.NewHeight(0, 0) +func NewDefaultMessageHandler(router sdk.Router, channelKeeper types.ChannelKeeper, capabilityKeeper types.CapabilityKeeper, unpacker codectypes.AnyUnpacker, customEncoders ...*MessageEncoders) messenger { + encoders := DefaultEncoders(unpacker) + for _, e := range customEncoders { + encoders = encoders.Merge(e) } - return ibcclienttypes.NewHeight(ibcTimeoutBlock.Revision, ibcTimeoutBlock.Height) + return NewMessageHandlerChain( + NewSDKMessageHandler(router, encoders), + NewIBCRawPacketHandler(channelKeeper, capabilityKeeper), + ) } -func convertWasmIBCTimeoutTimestampToCosmosTimestamp(timestamp *uint64) uint64 { - if timestamp == nil { - return 0 +func NewSDKMessageHandler(router sdk.Router, encoders msgEncoder) SDKMessageHandler { + return SDKMessageHandler{ + router: router, + encoders: encoders, } - return *timestamp } -func (h DefaultMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { +func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { sdkMsgs, err := h.encoders.Encode(ctx, contractAddr, contractIBCPortID, msg) if err != nil { return nil, nil, err @@ -342,7 +64,7 @@ func (h DefaultMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.Acc return } -func (h DefaultMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (*sdk.Result, error) { +func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (*sdk.Result, error) { if err := msg.ValidateBasic(); err != nil { return nil, err } @@ -365,25 +87,84 @@ func (h DefaultMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sd return res, nil } -func convertWasmCoinsToSdkCoins(coins []wasmvmtypes.Coin) (sdk.Coins, error) { - var toSend sdk.Coins - for _, coin := range coins { - c, err := convertWasmCoinToSdkCoin(coin) - if err != nil { - return nil, err +// MessageHandlerChain defines a chain of handlers that are called one by one until it can be handled. +type MessageHandlerChain struct { + handlers []messenger +} + +func NewMessageHandlerChain(first messenger, others ...messenger) *MessageHandlerChain { + r := &MessageHandlerChain{handlers: append([]messenger{first}, others...)} + for i := range r.handlers { + if r.handlers[i] == nil { + panic(fmt.Sprintf("handler must not be nil at position : %d", i)) + } + } + return r +} + +// DispatchMsg dispatch message to handlers. +func (m MessageHandlerChain) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) ([]sdk.Event, [][]byte, error) { + for _, h := range m.handlers { + events, data, err := h.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg) + switch { + case err == nil: + return events, data, err + case errors.Is(err, types.ErrUnknownMsg): + continue + default: + return events, data, err } - toSend = append(toSend, c) } - return toSend, nil + return nil, nil, sdkerrors.Wrap(types.ErrUnknownMsg, "no handler found") +} + +// IBCRawPacketHandler handels IBC.SendPacket messages which are published to an IBC channel. +type IBCRawPacketHandler struct { + channelKeeper types.ChannelKeeper + capabilityKeeper types.CapabilityKeeper } -func convertWasmCoinToSdkCoin(coin wasmvmtypes.Coin) (sdk.Coin, error) { - amount, ok := sdk.NewIntFromString(coin.Amount) +func NewIBCRawPacketHandler(chk types.ChannelKeeper, cak types.CapabilityKeeper) *IBCRawPacketHandler { + return &IBCRawPacketHandler{channelKeeper: chk, capabilityKeeper: cak} +} + +// DispatchMsg publishes a raw IBC packet onto the channel. +func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { + if msg.IBC == nil || msg.IBC.SendPacket == nil { + return nil, nil, types.ErrUnknownMsg + } + if contractIBCPortID == "" { + return nil, nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported") + } + contractIBCChannelID := msg.IBC.SendPacket.ChannelID + if contractIBCChannelID == "" { + return nil, nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel") + } + + sequence, found := h.channelKeeper.GetNextSequenceSend(ctx, contractIBCPortID, contractIBCChannelID) + if !found { + return nil, nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound, + "source port: %s, source channel: %s", contractIBCPortID, contractIBCChannelID, + ) + } + + channelInfo, ok := h.channelKeeper.GetChannel(ctx, contractIBCPortID, contractIBCChannelID) if !ok { - return sdk.Coin{}, sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, coin.Amount+coin.Denom) + return nil, nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found") } - return sdk.Coin{ - Denom: coin.Denom, - Amount: amount, - }, nil + channelCap, ok := h.capabilityKeeper.GetCapability(ctx, host.ChannelCapabilityPath(contractIBCPortID, contractIBCChannelID)) + if !ok { + return nil, nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability") + } + packet := channeltypes.NewPacket( + msg.IBC.SendPacket.Data, + sequence, + contractIBCPortID, + contractIBCChannelID, + channelInfo.Counterparty.PortId, + channelInfo.Counterparty.ChannelId, + convertWasmIBCTimeoutHeightToCosmosHeight(msg.IBC.SendPacket.TimeoutBlock), + convertWasmIBCTimeoutTimestampToCosmosTimestamp(msg.IBC.SendPacket.TimeoutTimestamp), + ) + return nil, nil, h.channelKeeper.SendPacket(ctx, channelCap, packet) } diff --git a/x/wasm/internal/keeper/handler_plugin_encoders.go b/x/wasm/internal/keeper/handler_plugin_encoders.go new file mode 100644 index 0000000000..20b19414da --- /dev/null +++ b/x/wasm/internal/keeper/handler_plugin_encoders.go @@ -0,0 +1,292 @@ +package keeper + +import ( + "encoding/json" + "fmt" + "github.com/CosmWasm/wasmd/x/wasm/internal/types" + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + distributiontypes "github.com/cosmos/cosmos-sdk/x/distribution/types" + ibctransfertypes "github.com/cosmos/cosmos-sdk/x/ibc/applications/transfer/types" + ibcclienttypes "github.com/cosmos/cosmos-sdk/x/ibc/core/02-client/types" + channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" +) + +type BankEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.BankMsg) ([]sdk.Msg, error) +type CustomEncoder func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) +type StakingEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.StakingMsg) ([]sdk.Msg, error) +type StargateEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.StargateMsg) ([]sdk.Msg, error) +type WasmEncoder func(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg, error) +type IBCEncoder func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) + +type MessageEncoders struct { + Bank func(sender sdk.AccAddress, msg *wasmvmtypes.BankMsg) ([]sdk.Msg, error) + Custom func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) + IBC func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) + Staking func(sender sdk.AccAddress, msg *wasmvmtypes.StakingMsg) ([]sdk.Msg, error) + Stargate func(sender sdk.AccAddress, msg *wasmvmtypes.StargateMsg) ([]sdk.Msg, error) + Wasm func(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg, error) +} + +func DefaultEncoders(unpacker codectypes.AnyUnpacker) MessageEncoders { + return MessageEncoders{ + Bank: EncodeBankMsg, + Custom: NoCustomMsg, + IBC: EncodeIBCMsg, + Staking: EncodeStakingMsg, + Stargate: EncodeStargateMsg(unpacker), + Wasm: EncodeWasmMsg, + } +} + +func (e MessageEncoders) Merge(o *MessageEncoders) MessageEncoders { + if o == nil { + return e + } + if o.Bank != nil { + e.Bank = o.Bank + } + if o.Custom != nil { + e.Custom = o.Custom + } + if o.IBC != nil { + e.IBC = o.IBC + } + if o.Staking != nil { + e.Staking = o.Staking + } + if o.Stargate != nil { + e.Stargate = o.Stargate + } + if o.Wasm != nil { + e.Wasm = o.Wasm + } + return e +} + +func (e MessageEncoders) Encode(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) ([]sdk.Msg, error) { + switch { + case msg.Bank != nil: + return e.Bank(contractAddr, msg.Bank) + case msg.Custom != nil: + return e.Custom(contractAddr, msg.Custom) + case msg.IBC != nil: + return e.IBC(ctx, contractAddr, contractIBCPortID, msg.IBC) + case msg.Staking != nil: + return e.Staking(contractAddr, msg.Staking) + case msg.Stargate != nil: + return e.Stargate(contractAddr, msg.Stargate) + case msg.Wasm != nil: + return e.Wasm(contractAddr, msg.Wasm) + } + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "unknown variant of Wasm") +} + +func EncodeBankMsg(sender sdk.AccAddress, msg *wasmvmtypes.BankMsg) ([]sdk.Msg, error) { + if msg.Send == nil { + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "unknown variant of Bank") + } + if len(msg.Send.Amount) == 0 { + return nil, nil + } + toSend, err := convertWasmCoinsToSdkCoins(msg.Send.Amount) + if err != nil { + return nil, err + } + sdkMsg := banktypes.MsgSend{ + FromAddress: sender.String(), + ToAddress: msg.Send.ToAddress, + Amount: toSend, + } + return []sdk.Msg{&sdkMsg}, nil +} + +func NoCustomMsg(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "custom variant not supported") +} + +func EncodeStakingMsg(sender sdk.AccAddress, msg *wasmvmtypes.StakingMsg) ([]sdk.Msg, error) { + switch { + case msg.Delegate != nil: + coin, err := convertWasmCoinToSdkCoin(msg.Delegate.Amount) + if err != nil { + return nil, err + } + sdkMsg := stakingtypes.MsgDelegate{ + DelegatorAddress: sender.String(), + ValidatorAddress: msg.Delegate.Validator, + Amount: coin, + } + return []sdk.Msg{&sdkMsg}, nil + + case msg.Redelegate != nil: + coin, err := convertWasmCoinToSdkCoin(msg.Redelegate.Amount) + if err != nil { + return nil, err + } + sdkMsg := stakingtypes.MsgBeginRedelegate{ + DelegatorAddress: sender.String(), + ValidatorSrcAddress: msg.Redelegate.SrcValidator, + ValidatorDstAddress: msg.Redelegate.DstValidator, + Amount: coin, + } + return []sdk.Msg{&sdkMsg}, nil + case msg.Undelegate != nil: + coin, err := convertWasmCoinToSdkCoin(msg.Undelegate.Amount) + if err != nil { + return nil, err + } + sdkMsg := stakingtypes.MsgUndelegate{ + DelegatorAddress: sender.String(), + ValidatorAddress: msg.Undelegate.Validator, + Amount: coin, + } + return []sdk.Msg{&sdkMsg}, nil + case msg.Withdraw != nil: + senderAddr := sender.String() + rcpt := senderAddr + if len(msg.Withdraw.Recipient) != 0 { + rcpt = msg.Withdraw.Recipient + } + setMsg := distributiontypes.MsgSetWithdrawAddress{ + DelegatorAddress: senderAddr, + WithdrawAddress: rcpt, + } + withdrawMsg := distributiontypes.MsgWithdrawDelegatorReward{ + DelegatorAddress: senderAddr, + ValidatorAddress: msg.Withdraw.Validator, + } + return []sdk.Msg{&setMsg, &withdrawMsg}, nil + default: + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "unknown variant of Staking") + } +} + +func EncodeStargateMsg(unpacker codectypes.AnyUnpacker) StargateEncoder { + return func(sender sdk.AccAddress, msg *wasmvmtypes.StargateMsg) ([]sdk.Msg, error) { + any := codectypes.Any{ + TypeUrl: msg.TypeURL, + Value: msg.Value, + } + var sdkMsg sdk.Msg + if err := unpacker.UnpackAny(&any, &sdkMsg); err != nil { + return nil, sdkerrors.Wrap(types.ErrInvalidMsg, fmt.Sprintf("Cannot unpack proto message with type URL: %s", msg.TypeURL)) + } + if err := codectypes.UnpackInterfaces(sdkMsg, unpacker); err != nil { + return nil, sdkerrors.Wrap(types.ErrInvalidMsg, fmt.Sprintf("UnpackInterfaces inside msg: %s", err)) + } + return []sdk.Msg{sdkMsg}, nil + } +} + +func EncodeWasmMsg(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg, error) { + switch { + case msg.Execute != nil: + coins, err := convertWasmCoinsToSdkCoins(msg.Execute.Send) + if err != nil { + return nil, err + } + + sdkMsg := types.MsgExecuteContract{ + Sender: sender.String(), + Contract: msg.Execute.ContractAddr, + Msg: msg.Execute.Msg, + Funds: coins, + } + return []sdk.Msg{&sdkMsg}, nil + case msg.Instantiate != nil: + coins, err := convertWasmCoinsToSdkCoins(msg.Instantiate.Send) + if err != nil { + return nil, err + } + + sdkMsg := types.MsgInstantiateContract{ + Sender: sender.String(), + CodeID: msg.Instantiate.CodeID, + Label: msg.Instantiate.Label, + InitMsg: msg.Instantiate.Msg, + Funds: coins, + } + return []sdk.Msg{&sdkMsg}, nil + case msg.Migrate != nil: + sdkMsg := types.MsgMigrateContract{ + Sender: sender.String(), + Contract: msg.Migrate.ContractAddr, + CodeID: msg.Migrate.NewCodeID, + MigrateMsg: msg.Migrate.Msg, + } + return []sdk.Msg{&sdkMsg}, nil + default: + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "unknown variant of Wasm") + } +} + +func EncodeIBCMsg(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) { + switch { + case msg.CloseChannel != nil: + return []sdk.Msg{&channeltypes.MsgChannelCloseInit{ + PortId: PortIDForContract(sender), + ChannelId: msg.CloseChannel.ChannelID, + Signer: sender.String(), + }}, nil + case msg.Transfer != nil: + amount, err := convertWasmCoinToSdkCoin(msg.Transfer.Amount) + if err != nil { + return nil, sdkerrors.Wrap(err, "amount") + } + portID := ibctransfertypes.ModuleName //todo: port can be customized in genesis. make this more flexible + msg := &ibctransfertypes.MsgTransfer{ + SourcePort: portID, + SourceChannel: msg.Transfer.ChannelID, + Token: amount, + Sender: sender.String(), + Receiver: msg.Transfer.ToAddress, + TimeoutHeight: convertWasmIBCTimeoutHeightToCosmosHeight(msg.Transfer.TimeoutBlock), + TimeoutTimestamp: convertWasmIBCTimeoutTimestampToCosmosTimestamp(msg.Transfer.TimeoutTimestamp), + } + return []sdk.Msg{msg}, nil + default: + return nil, sdkerrors.Wrap(types.ErrUnknownMsg, "Unknown variant of IBC") + } +} + +func convertWasmIBCTimeoutHeightToCosmosHeight(ibcTimeoutBlock *wasmvmtypes.IBCTimeoutBlock) ibcclienttypes.Height { + if ibcTimeoutBlock == nil { + return ibcclienttypes.NewHeight(0, 0) + } + return ibcclienttypes.NewHeight(ibcTimeoutBlock.Revision, ibcTimeoutBlock.Height) +} + +func convertWasmIBCTimeoutTimestampToCosmosTimestamp(timestamp *uint64) uint64 { + if timestamp == nil { + return 0 + } + return *timestamp +} + +func convertWasmCoinsToSdkCoins(coins []wasmvmtypes.Coin) (sdk.Coins, error) { + var toSend sdk.Coins + for _, coin := range coins { + c, err := convertWasmCoinToSdkCoin(coin) + if err != nil { + return nil, err + } + toSend = append(toSend, c) + } + return toSend, nil +} + +func convertWasmCoinToSdkCoin(coin wasmvmtypes.Coin) (sdk.Coin, error) { + amount, ok := sdk.NewIntFromString(coin.Amount) + if !ok { + return sdk.Coin{}, sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, coin.Amount+coin.Denom) + } + return sdk.Coin{ + Denom: coin.Denom, + Amount: amount, + }, nil +} diff --git a/x/wasm/internal/keeper/handler_plugin_encoders_test.go b/x/wasm/internal/keeper/handler_plugin_encoders_test.go new file mode 100644 index 0000000000..0151a4bf33 --- /dev/null +++ b/x/wasm/internal/keeper/handler_plugin_encoders_test.go @@ -0,0 +1,445 @@ +package keeper + +import ( + "encoding/json" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + ibctransfertypes "github.com/cosmos/cosmos-sdk/x/ibc/applications/transfer/types" + clienttypes "github.com/cosmos/cosmos-sdk/x/ibc/core/02-client/types" + channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" + "github.com/golang/protobuf/proto" + "testing" + + "github.com/CosmWasm/wasmd/x/wasm/internal/types" + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + sdk "github.com/cosmos/cosmos-sdk/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + distributiontypes "github.com/cosmos/cosmos-sdk/x/distribution/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncoding(t *testing.T) { + addr1 := RandomAccountAddress(t) + addr2 := RandomAccountAddress(t) + invalidAddr := "xrnd1d02kd90n38qvr3qb9qof83fn2d2" + valAddr := make(sdk.ValAddress, sdk.AddrLen) + valAddr[0] = 12 + valAddr2 := make(sdk.ValAddress, sdk.AddrLen) + valAddr2[1] = 123 + var timeoutVal uint64 = 100 + + jsonMsg := json.RawMessage(`{"foo": 123}`) + + bankMsg := &banktypes.MsgSend{ + FromAddress: addr2.String(), + ToAddress: addr1.String(), + Amount: sdk.Coins{ + sdk.NewInt64Coin("uatom", 12345), + sdk.NewInt64Coin("utgd", 54321), + }, + } + bankMsgBin, err := proto.Marshal(bankMsg) + require.NoError(t, err) + + content, err := codectypes.NewAnyWithValue(types.StoreCodeProposalFixture()) + require.NoError(t, err) + + proposalMsg := &govtypes.MsgSubmitProposal{ + Proposer: addr1.String(), + InitialDeposit: sdk.NewCoins(sdk.NewInt64Coin("uatom", 12345)), + Content: content, + } + proposalMsgBin, err := proto.Marshal(proposalMsg) + require.NoError(t, err) + + cases := map[string]struct { + sender sdk.AccAddress + srcMsg wasmvmtypes.CosmosMsg + srcIBCPort string + // set if valid + output []sdk.Msg + // set if invalid + isError bool + }{ + "simple send": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Bank: &wasmvmtypes.BankMsg{ + Send: &wasmvmtypes.SendMsg{ + ToAddress: addr2.String(), + Amount: []wasmvmtypes.Coin{ + { + Denom: "uatom", + Amount: "12345", + }, + { + Denom: "usdt", + Amount: "54321", + }, + }, + }, + }, + }, + output: []sdk.Msg{ + &banktypes.MsgSend{ + FromAddress: addr1.String(), + ToAddress: addr2.String(), + Amount: sdk.Coins{ + sdk.NewInt64Coin("uatom", 12345), + sdk.NewInt64Coin("usdt", 54321), + }, + }, + }, + }, + "invalid send amount": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Bank: &wasmvmtypes.BankMsg{ + Send: &wasmvmtypes.SendMsg{ + ToAddress: addr2.String(), + Amount: []wasmvmtypes.Coin{ + { + Denom: "uatom", + Amount: "123.456", + }, + }, + }, + }, + }, + isError: true, + }, + "invalid address": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Bank: &wasmvmtypes.BankMsg{ + Send: &wasmvmtypes.SendMsg{ + ToAddress: invalidAddr, + Amount: []wasmvmtypes.Coin{ + { + Denom: "uatom", + Amount: "7890", + }, + }, + }, + }, + }, + isError: false, // addresses are checked in the handler + output: []sdk.Msg{ + &banktypes.MsgSend{ + FromAddress: addr1.String(), + ToAddress: invalidAddr, + Amount: sdk.Coins{ + sdk.NewInt64Coin("uatom", 7890), + }, + }, + }, + }, + "wasm execute": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Wasm: &wasmvmtypes.WasmMsg{ + Execute: &wasmvmtypes.ExecuteMsg{ + ContractAddr: addr2.String(), + Msg: jsonMsg, + Send: []wasmvmtypes.Coin{ + wasmvmtypes.NewCoin(12, "eth"), + }, + }, + }, + }, + output: []sdk.Msg{ + &types.MsgExecuteContract{ + Sender: addr1.String(), + Contract: addr2.String(), + Msg: jsonMsg, + Funds: sdk.NewCoins(sdk.NewInt64Coin("eth", 12)), + }, + }, + }, + "wasm instantiate": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Wasm: &wasmvmtypes.WasmMsg{ + Instantiate: &wasmvmtypes.InstantiateMsg{ + CodeID: 7, + Msg: jsonMsg, + Send: []wasmvmtypes.Coin{ + wasmvmtypes.NewCoin(123, "eth"), + }, + Label: "myLabel", + }, + }, + }, + output: []sdk.Msg{ + &types.MsgInstantiateContract{ + Sender: addr1.String(), + CodeID: 7, + Label: "myLabel", + InitMsg: jsonMsg, + Funds: sdk.NewCoins(sdk.NewInt64Coin("eth", 123)), + }, + }, + }, + "wasm migrate": { + sender: addr2, + srcMsg: wasmvmtypes.CosmosMsg{ + Wasm: &wasmvmtypes.WasmMsg{ + Migrate: &wasmvmtypes.MigrateMsg{ + ContractAddr: addr1.String(), + NewCodeID: 12, + Msg: jsonMsg, + }, + }, + }, + output: []sdk.Msg{ + &types.MsgMigrateContract{ + Sender: addr2.String(), + Contract: addr1.String(), + CodeID: 12, + MigrateMsg: jsonMsg, + }, + }, + }, + "staking delegate": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Delegate: &wasmvmtypes.DelegateMsg{ + Validator: valAddr.String(), + Amount: wasmvmtypes.NewCoin(777, "stake"), + }, + }, + }, + output: []sdk.Msg{ + &stakingtypes.MsgDelegate{ + DelegatorAddress: addr1.String(), + ValidatorAddress: valAddr.String(), + Amount: sdk.NewInt64Coin("stake", 777), + }, + }, + }, + "staking delegate to non-validator": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Delegate: &wasmvmtypes.DelegateMsg{ + Validator: addr2.String(), + Amount: wasmvmtypes.NewCoin(777, "stake"), + }, + }, + }, + isError: false, // fails in the handler + output: []sdk.Msg{ + &stakingtypes.MsgDelegate{ + DelegatorAddress: addr1.String(), + ValidatorAddress: addr2.String(), + Amount: sdk.NewInt64Coin("stake", 777), + }, + }, + }, + "staking undelegate": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Undelegate: &wasmvmtypes.UndelegateMsg{ + Validator: valAddr.String(), + Amount: wasmvmtypes.NewCoin(555, "stake"), + }, + }, + }, + output: []sdk.Msg{ + &stakingtypes.MsgUndelegate{ + DelegatorAddress: addr1.String(), + ValidatorAddress: valAddr.String(), + Amount: sdk.NewInt64Coin("stake", 555), + }, + }, + }, + "staking redelegate": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Redelegate: &wasmvmtypes.RedelegateMsg{ + SrcValidator: valAddr.String(), + DstValidator: valAddr2.String(), + Amount: wasmvmtypes.NewCoin(222, "stake"), + }, + }, + }, + output: []sdk.Msg{ + &stakingtypes.MsgBeginRedelegate{ + DelegatorAddress: addr1.String(), + ValidatorSrcAddress: valAddr.String(), + ValidatorDstAddress: valAddr2.String(), + Amount: sdk.NewInt64Coin("stake", 222), + }, + }, + }, + "staking withdraw (implicit recipient)": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Withdraw: &wasmvmtypes.WithdrawMsg{ + Validator: valAddr2.String(), + }, + }, + }, + output: []sdk.Msg{ + &distributiontypes.MsgSetWithdrawAddress{ + DelegatorAddress: addr1.String(), + WithdrawAddress: addr1.String(), + }, + &distributiontypes.MsgWithdrawDelegatorReward{ + DelegatorAddress: addr1.String(), + ValidatorAddress: valAddr2.String(), + }, + }, + }, + "staking withdraw (explicit recipient)": { + sender: addr1, + srcMsg: wasmvmtypes.CosmosMsg{ + Staking: &wasmvmtypes.StakingMsg{ + Withdraw: &wasmvmtypes.WithdrawMsg{ + Validator: valAddr2.String(), + Recipient: addr2.String(), + }, + }, + }, + output: []sdk.Msg{ + &distributiontypes.MsgSetWithdrawAddress{ + DelegatorAddress: addr1.String(), + WithdrawAddress: addr2.String(), + }, + &distributiontypes.MsgWithdrawDelegatorReward{ + DelegatorAddress: addr1.String(), + ValidatorAddress: valAddr2.String(), + }, + }, + }, + "stargate encoded bank msg": { + sender: addr2, + srcMsg: wasmvmtypes.CosmosMsg{ + Stargate: &wasmvmtypes.StargateMsg{ + TypeURL: "/cosmos.bank.v1beta1.MsgSend", + Value: bankMsgBin, + }, + }, + output: []sdk.Msg{bankMsg}, + }, + "stargate encoded msg with any type": { + sender: addr2, + srcMsg: wasmvmtypes.CosmosMsg{ + Stargate: &wasmvmtypes.StargateMsg{ + TypeURL: "/cosmos.gov.v1beta1.MsgSubmitProposal", + Value: proposalMsgBin, + }, + }, + output: []sdk.Msg{proposalMsg}, + }, + "stargate encoded invalid typeUrl": { + sender: addr2, + srcMsg: wasmvmtypes.CosmosMsg{ + Stargate: &wasmvmtypes.StargateMsg{ + TypeURL: "/cosmos.bank.v2.MsgSend", + Value: bankMsgBin, + }, + }, + isError: true, + }, + "IBC transfer with block timeout": { + sender: addr1, + srcIBCPort: "myIBCPort", + srcMsg: wasmvmtypes.CosmosMsg{ + IBC: &wasmvmtypes.IBCMsg{ + Transfer: &wasmvmtypes.TransferMsg{ + ChannelID: "myChanID", + ToAddress: addr2.String(), + Amount: wasmvmtypes.Coin{ + Denom: "ALX", + Amount: "1", + }, + TimeoutBlock: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}, + }, + }, + }, + output: []sdk.Msg{ + &ibctransfertypes.MsgTransfer{ + SourcePort: "transfer", + SourceChannel: "myChanID", + Token: sdk.Coin{ + Denom: "ALX", + Amount: sdk.NewInt(1), + }, + Sender: addr1.String(), + Receiver: addr2.String(), + TimeoutHeight: clienttypes.Height{RevisionNumber: 1, RevisionHeight: 2}, + }, + }, + }, + "IBC transfer with time timeout": { + sender: addr1, + srcIBCPort: "myIBCPort", + srcMsg: wasmvmtypes.CosmosMsg{ + IBC: &wasmvmtypes.IBCMsg{ + Transfer: &wasmvmtypes.TransferMsg{ + ChannelID: "myChanID", + ToAddress: addr2.String(), + Amount: wasmvmtypes.Coin{ + Denom: "ALX", + Amount: "1", + }, + TimeoutTimestamp: &timeoutVal, + }, + }, + }, + output: []sdk.Msg{ + &ibctransfertypes.MsgTransfer{ + SourcePort: "transfer", + SourceChannel: "myChanID", + Token: sdk.Coin{ + Denom: "ALX", + Amount: sdk.NewInt(1), + }, + Sender: addr1.String(), + Receiver: addr2.String(), + TimeoutTimestamp: 100, + }, + }, + }, + "IBC close channel": { + sender: addr1, + srcIBCPort: "myIBCPort", + srcMsg: wasmvmtypes.CosmosMsg{ + IBC: &wasmvmtypes.IBCMsg{ + CloseChannel: &wasmvmtypes.CloseChannelMsg{ + ChannelID: "channel-1", + }, + }, + }, + output: []sdk.Msg{ + &channeltypes.MsgChannelCloseInit{ + PortId: "wasm." + addr1.String(), + ChannelId: "channel-1", + Signer: addr1.String(), + }, + }, + }, + } + encodingConfig := MakeEncodingConfig(t) + encoder := DefaultEncoders(encodingConfig.Marshaler) + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + var ctx sdk.Context + res, err := encoder.Encode(ctx, tc.sender, tc.srcIBCPort, tc.srcMsg) + if tc.isError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.output, res) + } + }) + } +} diff --git a/x/wasm/internal/keeper/handler_plugin_test.go b/x/wasm/internal/keeper/handler_plugin_test.go index aed42ffcaa..f28be41ade 100644 --- a/x/wasm/internal/keeper/handler_plugin_test.go +++ b/x/wasm/internal/keeper/handler_plugin_test.go @@ -3,456 +3,247 @@ package keeper import ( "encoding/json" "github.com/CosmWasm/wasmd/x/wasm/internal/keeper/wasmtesting" - codectypes "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/CosmWasm/wasmd/x/wasm/internal/types" + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + "github.com/cosmos/cosmos-sdk/baseapp" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types" - govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" - ibctransfertypes "github.com/cosmos/cosmos-sdk/x/ibc/applications/transfer/types" clienttypes "github.com/cosmos/cosmos-sdk/x/ibc/core/02-client/types" channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" ibcexported "github.com/cosmos/cosmos-sdk/x/ibc/core/exported" - "github.com/golang/protobuf/proto" - "testing" - - "github.com/CosmWasm/wasmd/x/wasm/internal/types" - wasmvmtypes "github.com/CosmWasm/wasmvm/types" - sdk "github.com/cosmos/cosmos-sdk/types" - banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" - distributiontypes "github.com/cosmos/cosmos-sdk/x/distribution/types" - stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "testing" ) -func TestEncoding(t *testing.T) { - addr1 := RandomAccountAddress(t) - addr2 := RandomAccountAddress(t) - invalidAddr := "xrnd1d02kd90n38qvr3qb9qof83fn2d2" - valAddr := make(sdk.ValAddress, sdk.AddrLen) - valAddr[0] = 12 - valAddr2 := make(sdk.ValAddress, sdk.AddrLen) - valAddr2[1] = 123 - var timeoutVal uint64 = 100 - - jsonMsg := json.RawMessage(`{"foo": 123}`) - - bankMsg := &banktypes.MsgSend{ - FromAddress: addr2.String(), - ToAddress: addr1.String(), - Amount: sdk.Coins{ - sdk.NewInt64Coin("uatom", 12345), - sdk.NewInt64Coin("utgd", 54321), - }, - } - bankMsgBin, err := proto.Marshal(bankMsg) - require.NoError(t, err) +func TestMessageHandlerChainDispatch(t *testing.T) { + capturingHandler, gotMsgs := wasmtesting.NewCapturingMessageHandler() - content, err := codectypes.NewAnyWithValue(types.StoreCodeProposalFixture()) - require.NoError(t, err) + alwaysUnknownMsgHandler := &wasmtesting.MockMessageHandler{ + DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { + return nil, nil, types.ErrUnknownMsg + }} - proposalMsg := &govtypes.MsgSubmitProposal{ - Proposer: addr1.String(), - InitialDeposit: sdk.NewCoins(sdk.NewInt64Coin("uatom", 12345)), - Content: content, - } - proposalMsgBin, err := proto.Marshal(proposalMsg) - require.NoError(t, err) + assertNotCalledHandler := &wasmtesting.MockMessageHandler{ + DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { + t.Fatal("not expected to be called") + return + }} - cases := map[string]struct { - sender sdk.AccAddress - srcMsg wasmvmtypes.CosmosMsg - srcIBCPort string - // set if valid - output []sdk.Msg - // set if invalid - isError bool + myMsg := wasmvmtypes.CosmosMsg{Custom: []byte(`{}`)} + specs := map[string]struct { + handlers []messenger + expErr *sdkerrors.Error + expEvents []sdk.Event }{ - "simple send": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Bank: &wasmvmtypes.BankMsg{ - Send: &wasmvmtypes.SendMsg{ - ToAddress: addr2.String(), - Amount: []wasmvmtypes.Coin{ - { - Denom: "uatom", - Amount: "12345", - }, - { - Denom: "usdt", - Amount: "54321", - }, - }, - }, - }, - }, - output: []sdk.Msg{ - &banktypes.MsgSend{ - FromAddress: addr1.String(), - ToAddress: addr2.String(), - Amount: sdk.Coins{ - sdk.NewInt64Coin("uatom", 12345), - sdk.NewInt64Coin("usdt", 54321), - }, - }, - }, + "single handler": { + handlers: []messenger{capturingHandler}, }, - "invalid send amount": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Bank: &wasmvmtypes.BankMsg{ - Send: &wasmvmtypes.SendMsg{ - ToAddress: addr2.String(), - Amount: []wasmvmtypes.Coin{ - { - Denom: "uatom", - Amount: "123.456", - }, - }, - }, - }, - }, - isError: true, + "passed to next handler": { + handlers: []messenger{alwaysUnknownMsgHandler, capturingHandler}, }, - "invalid address": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Bank: &wasmvmtypes.BankMsg{ - Send: &wasmvmtypes.SendMsg{ - ToAddress: invalidAddr, - Amount: []wasmvmtypes.Coin{ - { - Denom: "uatom", - Amount: "7890", - }, - }, - }, - }, - }, - isError: false, // addresses are checked in the handler - output: []sdk.Msg{ - &banktypes.MsgSend{ - FromAddress: addr1.String(), - ToAddress: invalidAddr, - Amount: sdk.Coins{ - sdk.NewInt64Coin("uatom", 7890), - }, - }, - }, + "stops iteration when handled": { + handlers: []messenger{capturingHandler, assertNotCalledHandler}, }, - "wasm execute": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Wasm: &wasmvmtypes.WasmMsg{ - Execute: &wasmvmtypes.ExecuteMsg{ - ContractAddr: addr2.String(), - Msg: jsonMsg, - Send: []wasmvmtypes.Coin{ - wasmvmtypes.NewCoin(12, "eth"), - }, - }, - }, - }, - output: []sdk.Msg{ - &types.MsgExecuteContract{ - Sender: addr1.String(), - Contract: addr2.String(), - Msg: jsonMsg, - Funds: sdk.NewCoins(sdk.NewInt64Coin("eth", 12)), - }, - }, + "stops iteration on handler error": { + handlers: []messenger{&wasmtesting.MockMessageHandler{ + DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { + return nil, nil, types.ErrInvalidMsg + }}, assertNotCalledHandler}, + expErr: types.ErrInvalidMsg, }, - "wasm instantiate": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Wasm: &wasmvmtypes.WasmMsg{ - Instantiate: &wasmvmtypes.InstantiateMsg{ - CodeID: 7, - Msg: jsonMsg, - Send: []wasmvmtypes.Coin{ - wasmvmtypes.NewCoin(123, "eth"), - }, - Label: "myLabel", - }, - }, - }, - output: []sdk.Msg{ - &types.MsgInstantiateContract{ - Sender: addr1.String(), - CodeID: 7, - Label: "myLabel", - InitMsg: jsonMsg, - Funds: sdk.NewCoins(sdk.NewInt64Coin("eth", 123)), - }, - }, + "return events when handle": { + handlers: []messenger{&wasmtesting.MockMessageHandler{ + DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) { + _, data, _ = capturingHandler.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg) + return []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, data, nil + }}, + }, + expEvents: []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, }, - "wasm migrate": { - sender: addr2, - srcMsg: wasmvmtypes.CosmosMsg{ - Wasm: &wasmvmtypes.WasmMsg{ - Migrate: &wasmvmtypes.MigrateMsg{ - ContractAddr: addr1.String(), - NewCodeID: 12, - Msg: jsonMsg, - }, - }, - }, - output: []sdk.Msg{ - &types.MsgMigrateContract{ - Sender: addr2.String(), - Contract: addr1.String(), - CodeID: 12, - MigrateMsg: jsonMsg, - }, - }, + "return error when none can handle": { + handlers: []messenger{alwaysUnknownMsgHandler}, + expErr: types.ErrUnknownMsg, }, - "staking delegate": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Delegate: &wasmvmtypes.DelegateMsg{ - Validator: valAddr.String(), - Amount: wasmvmtypes.NewCoin(777, "stake"), - }, - }, - }, - output: []sdk.Msg{ - &stakingtypes.MsgDelegate{ - DelegatorAddress: addr1.String(), - ValidatorAddress: valAddr.String(), - Amount: sdk.NewInt64Coin("stake", 777), - }, - }, - }, - "staking delegate to non-validator": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Delegate: &wasmvmtypes.DelegateMsg{ - Validator: addr2.String(), - Amount: wasmvmtypes.NewCoin(777, "stake"), - }, - }, - }, - isError: false, // fails in the handler - output: []sdk.Msg{ - &stakingtypes.MsgDelegate{ - DelegatorAddress: addr1.String(), - ValidatorAddress: addr2.String(), - Amount: sdk.NewInt64Coin("stake", 777), - }, - }, - }, - "staking undelegate": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Undelegate: &wasmvmtypes.UndelegateMsg{ - Validator: valAddr.String(), - Amount: wasmvmtypes.NewCoin(555, "stake"), - }, - }, - }, - output: []sdk.Msg{ - &stakingtypes.MsgUndelegate{ - DelegatorAddress: addr1.String(), - ValidatorAddress: valAddr.String(), - Amount: sdk.NewInt64Coin("stake", 555), - }, - }, - }, - "staking redelegate": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Redelegate: &wasmvmtypes.RedelegateMsg{ - SrcValidator: valAddr.String(), - DstValidator: valAddr2.String(), - Amount: wasmvmtypes.NewCoin(222, "stake"), - }, - }, - }, - output: []sdk.Msg{ - &stakingtypes.MsgBeginRedelegate{ - DelegatorAddress: addr1.String(), - ValidatorSrcAddress: valAddr.String(), - ValidatorDstAddress: valAddr2.String(), - Amount: sdk.NewInt64Coin("stake", 222), - }, - }, - }, - "staking withdraw (implicit recipient)": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Withdraw: &wasmvmtypes.WithdrawMsg{ - Validator: valAddr2.String(), - }, - }, - }, - output: []sdk.Msg{ - &distributiontypes.MsgSetWithdrawAddress{ - DelegatorAddress: addr1.String(), - WithdrawAddress: addr1.String(), - }, - &distributiontypes.MsgWithdrawDelegatorReward{ - DelegatorAddress: addr1.String(), - ValidatorAddress: valAddr2.String(), - }, - }, - }, - "staking withdraw (explicit recipient)": { - sender: addr1, - srcMsg: wasmvmtypes.CosmosMsg{ - Staking: &wasmvmtypes.StakingMsg{ - Withdraw: &wasmvmtypes.WithdrawMsg{ - Validator: valAddr2.String(), - Recipient: addr2.String(), - }, - }, - }, - output: []sdk.Msg{ - &distributiontypes.MsgSetWithdrawAddress{ - DelegatorAddress: addr1.String(), - WithdrawAddress: addr2.String(), - }, - &distributiontypes.MsgWithdrawDelegatorReward{ - DelegatorAddress: addr1.String(), - ValidatorAddress: valAddr2.String(), - }, - }, - }, - "stargate encoded bank msg": { - sender: addr2, - srcMsg: wasmvmtypes.CosmosMsg{ - Stargate: &wasmvmtypes.StargateMsg{ - TypeURL: "/cosmos.bank.v1beta1.MsgSend", - Value: bankMsgBin, - }, - }, - output: []sdk.Msg{bankMsg}, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + *gotMsgs = make([]wasmvmtypes.CosmosMsg, 0) + + // when + h := MessageHandlerChain{spec.handlers} + gotEvents, gotData, gotErr := h.DispatchMsg(sdk.Context{}, RandomAccountAddress(t), "anyPort", myMsg) + + // then + require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr) + if spec.expErr != nil { + return + } + assert.Equal(t, []wasmvmtypes.CosmosMsg{myMsg}, *gotMsgs) + assert.Equal(t, [][]byte{{1}}, gotData) // {1} is default in capturing handler + assert.Equal(t, spec.expEvents, gotEvents) + }) + } +} + +func TestSDKMessageHandlerDispatch(t *testing.T) { + myEvent := sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar")) + const myData = "myData" + myRouterResult := sdk.Result{ + Data: []byte(myData), + Events: sdk.Events{myEvent}.ToABCIEvents(), + } + + var gotMsg []sdk.Msg + capturingRouteFn := func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { + gotMsg = append(gotMsg, msg) + return &myRouterResult, nil + } + + myContractAddr := RandomAccountAddress(t) + myContractMessage := wasmvmtypes.CosmosMsg{Custom: []byte("{}")} + + specs := map[string]struct { + srcRoute sdk.Route + srcEncoder CustomEncoder + expErr *sdkerrors.Error + expMsgDispatched int + }{ + "all good": { + srcRoute: sdk.NewRoute(types.RouterKey, capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + myMsg := types.MsgExecuteContract{ + Sender: myContractAddr.String(), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("{}"), + } + return []sdk.Msg{&myMsg}, nil + }, + expMsgDispatched: 1, }, - "stargate encoded msg with any type": { - sender: addr2, - srcMsg: wasmvmtypes.CosmosMsg{ - Stargate: &wasmvmtypes.StargateMsg{ - TypeURL: "/cosmos.gov.v1beta1.MsgSubmitProposal", - Value: proposalMsgBin, - }, - }, - output: []sdk.Msg{proposalMsg}, + "multiple output msgs": { + srcRoute: sdk.NewRoute(types.RouterKey, capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + first := &types.MsgExecuteContract{ + Sender: myContractAddr.String(), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("{}"), + } + second := &types.MsgExecuteContract{ + Sender: myContractAddr.String(), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("{}"), + } + return []sdk.Msg{first, second}, nil + }, + expMsgDispatched: 2, }, - "stargate encoded invalid typeUrl": { - sender: addr2, - srcMsg: wasmvmtypes.CosmosMsg{ - Stargate: &wasmvmtypes.StargateMsg{ - TypeURL: "/cosmos.bank.v2.MsgSend", - Value: bankMsgBin, - }, - }, - isError: true, + "invalid sdk message rejected": { + srcRoute: sdk.NewRoute(types.RouterKey, capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + invalidMsg := types.MsgExecuteContract{ + Sender: myContractAddr.String(), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("INVALID_JSON"), + } + return []sdk.Msg{&invalidMsg}, nil + }, + expErr: types.ErrInvalid, }, - "IBC transfer with block timeout": { - sender: addr1, - srcIBCPort: "myIBCPort", - srcMsg: wasmvmtypes.CosmosMsg{ - IBC: &wasmvmtypes.IBCMsg{ - Transfer: &wasmvmtypes.TransferMsg{ - ChannelID: "myChanID", - ToAddress: addr2.String(), - Amount: wasmvmtypes.Coin{ - Denom: "ALX", - Amount: "1", - }, - TimeoutBlock: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}, - }, - }, - }, - output: []sdk.Msg{ - &ibctransfertypes.MsgTransfer{ - SourcePort: "transfer", - SourceChannel: "myChanID", - Token: sdk.Coin{ - Denom: "ALX", - Amount: sdk.NewInt(1), - }, - Sender: addr1.String(), - Receiver: addr2.String(), - TimeoutHeight: clienttypes.Height{RevisionNumber: 1, RevisionHeight: 2}, - }, - }, + "invalid sender rejected": { + srcRoute: sdk.NewRoute(types.RouterKey, capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + invalidMsg := types.MsgExecuteContract{ + Sender: RandomBech32AccountAddress(t), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("{}"), + } + return []sdk.Msg{&invalidMsg}, nil + }, + expErr: sdkerrors.ErrUnauthorized, }, - "IBC transfer with time timeout": { - sender: addr1, - srcIBCPort: "myIBCPort", - srcMsg: wasmvmtypes.CosmosMsg{ - IBC: &wasmvmtypes.IBCMsg{ - Transfer: &wasmvmtypes.TransferMsg{ - ChannelID: "myChanID", - ToAddress: addr2.String(), - Amount: wasmvmtypes.Coin{ - Denom: "ALX", - Amount: "1", - }, - TimeoutTimestamp: &timeoutVal, - }, - }, - }, - output: []sdk.Msg{ - &ibctransfertypes.MsgTransfer{ - SourcePort: "transfer", - SourceChannel: "myChanID", - Token: sdk.Coin{ - Denom: "ALX", - Amount: sdk.NewInt(1), - }, - Sender: addr1.String(), - Receiver: addr2.String(), - TimeoutTimestamp: 100, - }, - }, + "unroutable message rejected": { + srcRoute: sdk.NewRoute("nothing", capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + myMsg := types.MsgExecuteContract{ + Sender: myContractAddr.String(), + Contract: RandomBech32AccountAddress(t), + Msg: []byte("{}"), + } + return []sdk.Msg{&myMsg}, nil + }, + expErr: sdkerrors.ErrUnknownRequest, }, - "IBC close channel": { - sender: addr1, - srcIBCPort: "myIBCPort", - srcMsg: wasmvmtypes.CosmosMsg{ - IBC: &wasmvmtypes.IBCMsg{ - CloseChannel: &wasmvmtypes.CloseChannelMsg{ - ChannelID: "channel-1", - }, - }, - }, - output: []sdk.Msg{ - &channeltypes.MsgChannelCloseInit{ - PortId: "wasm." + addr1.String(), - ChannelId: "channel-1", - Signer: addr1.String(), - }, + "encoding error passed": { + srcRoute: sdk.NewRoute("nothing", capturingRouteFn), + srcEncoder: func(sender sdk.AccAddress, msg json.RawMessage) ([]sdk.Msg, error) { + myErr := types.ErrUnpinContractFailed + return nil, myErr }, + expErr: types.ErrUnpinContractFailed, }, } - encodingConfig := MakeEncodingConfig(t) - encoder := DefaultEncoders(nil, nil, encodingConfig.Marshaler) - for name, tc := range cases { - tc := tc + for name, spec := range specs { t.Run(name, func(t *testing.T) { - var ctx sdk.Context - res, err := encoder.Encode(ctx, tc.sender, tc.srcIBCPort, tc.srcMsg) - if tc.isError { - require.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tc.output, res) + gotMsg = make([]sdk.Msg, 0) + router := baseapp.NewRouter() + router.AddRoute(spec.srcRoute) + + // when + ctx := sdk.Context{} + h := NewSDKMessageHandler(router, MessageEncoders{Custom: spec.srcEncoder}) + gotEvents, gotData, gotErr := h.DispatchMsg(ctx, myContractAddr, "myPort", myContractMessage) + + // then + require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr) + if spec.expErr != nil { + require.Len(t, gotMsg, 0) + return + } + assert.Len(t, gotMsg, spec.expMsgDispatched) + for i := 0; i < spec.expMsgDispatched; i++ { + assert.Equal(t, myEvent, gotEvents[i]) + assert.Equal(t, []byte(myData), gotData[i]) } }) } } -func TestEncodeIBCSendPacket(t *testing.T) { +func TestIBCRawPacketHandler(t *testing.T) { ibcPort := "contractsIBCPort" var ctx sdk.Context + + var capturedPacket ibcexported.PacketI + + chanKeeper := &wasmtesting.MockChannelKeeper{ + GetNextSequenceSendFn: func(ctx sdk.Context, portID, channelID string) (uint64, bool) { + return 1, true + }, + GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channeltypes.Channel, bool) { + return channeltypes.Channel{ + Counterparty: channeltypes.NewCounterparty( + "other-port", + "other-channel-1", + )}, true + }, + SendPacketFn: func(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error { + capturedPacket = packet + return nil + }, + } + capKeeper := &wasmtesting.MockCapabilityKeeper{ + GetCapabilityFn: func(ctx sdk.Context, name string) (*capabilitytypes.Capability, bool) { + return &capabilitytypes.Capability{}, true + }, + } + specs := map[string]struct { srcMsg wasmvmtypes.SendPacketMsg + chanKeeper types.ChannelKeeper + capKeeper types.CapabilityKeeper expPacketSent channeltypes.Packet + expErr *sdkerrors.Error }{ "all good": { srcMsg: wasmvmtypes.SendPacketMsg{ @@ -460,6 +251,8 @@ func TestEncodeIBCSendPacket(t *testing.T) { Data: []byte("myData"), TimeoutBlock: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}, }, + chanKeeper: chanKeeper, + capKeeper: capKeeper, expPacketSent: channeltypes.Packet{ Sequence: 1, SourcePort: ibcPort, @@ -470,38 +263,46 @@ func TestEncodeIBCSendPacket(t *testing.T) { TimeoutHeight: clienttypes.Height{RevisionNumber: 1, RevisionHeight: 2}, }, }, + "sequence not found returns error": { + srcMsg: wasmvmtypes.SendPacketMsg{ + ChannelID: "channel-1", + Data: []byte("myData"), + TimeoutBlock: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}, + }, + chanKeeper: &wasmtesting.MockChannelKeeper{ + GetNextSequenceSendFn: func(ctx sdk.Context, portID, channelID string) (uint64, bool) { + return 0, false + }}, + expErr: channeltypes.ErrSequenceSendNotFound, + }, + "capability not found returns error": { + srcMsg: wasmvmtypes.SendPacketMsg{ + ChannelID: "channel-1", + Data: []byte("myData"), + TimeoutBlock: &wasmvmtypes.IBCTimeoutBlock{Revision: 1, Height: 2}, + }, + chanKeeper: chanKeeper, + capKeeper: wasmtesting.MockCapabilityKeeper{ + GetCapabilityFn: func(ctx sdk.Context, name string) (*capabilitytypes.Capability, bool) { + return nil, false + }}, + expErr: channeltypes.ErrChannelCapabilityNotFound, + }, } for name, spec := range specs { t.Run(name, func(t *testing.T) { - var gotPacket ibcexported.PacketI - - var chanKeeper types.ChannelKeeper = &wasmtesting.MockChannelKeeper{ - GetNextSequenceSendFn: func(ctx sdk.Context, portID, channelID string) (uint64, bool) { - return 1, true - }, - GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channeltypes.Channel, bool) { - return channeltypes.Channel{ - Counterparty: channeltypes.NewCounterparty( - "other-port", - "other-channel-1", - )}, true - }, - SendPacketFn: func(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error { - gotPacket = packet - return nil - }, + capturedPacket = nil + // when + h := NewIBCRawPacketHandler(spec.chanKeeper, spec.capKeeper) + data, evts, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}) + // then + require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr) + if spec.expErr != nil { + return } - var capKeeper types.CapabilityKeeper = &wasmtesting.MockCapabilityKeeper{ - GetCapabilityFn: func(ctx sdk.Context, name string) (*capabilitytypes.Capability, bool) { - return &capabilitytypes.Capability{}, true - }, - } - sender := RandomAccountAddress(t) - res, err := EncodeIBCMsg(chanKeeper, capKeeper)(ctx, sender, ibcPort, &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}) - - require.NoError(t, err) - assert.Nil(t, res) - assert.Equal(t, spec.expPacketSent, gotPacket) + assert.Nil(t, data) + assert.Nil(t, evts) + assert.Equal(t, spec.expPacketSent, capturedPacket) }) } } diff --git a/x/wasm/internal/keeper/ibc_test.go b/x/wasm/internal/keeper/ibc_test.go index 45519f6123..fc94f7ecac 100644 --- a/x/wasm/internal/keeper/ibc_test.go +++ b/x/wasm/internal/keeper/ibc_test.go @@ -10,14 +10,14 @@ import ( ) func TestDontBindPortNonIBCContract(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) example := InstantiateHackatomExampleContract(t, ctx, keepers) // ensure we bound the port _, _, err := keepers.IBCKeeper.PortKeeper.LookupModuleByPort(ctx, keepers.WasmKeeper.GetContractInfo(ctx, example.Contract).IBCPortID) require.Error(t, err) } func TestBindingPortForIBCContractOnInstantiate(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) example := InstantiateIBCReflectContract(t, ctx, keepers) // ensure we bound the port owner, _, err := keepers.IBCKeeper.PortKeeper.LookupModuleByPort(ctx, keepers.WasmKeeper.GetContractInfo(ctx, example.Contract).IBCPortID) require.NoError(t, err) diff --git a/x/wasm/internal/keeper/keeper.go b/x/wasm/internal/keeper/keeper.go index 8b5b44611c..e8cfad3731 100644 --- a/x/wasm/internal/keeper/keeper.go +++ b/x/wasm/internal/keeper/keeper.go @@ -49,7 +49,15 @@ type Option interface { apply(*Keeper) } +// WasmVMQueryHandler is an extension point for custom query handler implementations +type wasmVMQueryHandler interface { + // HandleQuery executes the requested query + HandleQuery(ctx sdk.Context, caller sdk.AccAddress, request wasmvmtypes.QueryRequest) ([]byte, error) +} + +// messenger is an extension point for custom wasmVM message handling type messenger interface { + // DispatchMsg encodes the wasmVM message and dispatches it. DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, err error) } @@ -60,17 +68,16 @@ type coinTransferrer interface { // Keeper will have a reference to Wasmer with it's own data directory. type Keeper struct { - storeKey sdk.StoreKey - cdc codec.Marshaler - accountKeeper types.AccountKeeper - bank coinTransferrer - ChannelKeeper types.ChannelKeeper - portKeeper types.PortKeeper - capabilityKeeper types.CapabilityKeeper - - wasmer types.WasmerEngine - queryPlugins QueryPlugins - messenger messenger + storeKey sdk.StoreKey + cdc codec.Marshaler + accountKeeper types.AccountKeeper + bank coinTransferrer + ChannelKeeper types.ChannelKeeper + portKeeper types.PortKeeper + capabilityKeeper types.CapabilityKeeper + wasmVM types.WasmerEngine + wasmVMQueryHandler wasmVMQueryHandler + messenger messenger // queryGasLimit is the max wasmvm gas that can be spent on executing a query with a contract queryGasLimit uint64 authZPolicy AuthorizationPolicy @@ -95,8 +102,6 @@ func NewKeeper( homeDir string, wasmConfig types.WasmConfig, supportedFeatures string, - customEncoders *MessageEncoders, - customPlugins *QueryPlugins, opts ...Option, ) Keeper { wasmer, err := wasmvm.NewVM(filepath.Join(homeDir, "wasm"), supportedFeatures, contractMemoryLimit, wasmConfig.ContractDebugMode, wasmConfig.MemoryCacheSize) @@ -111,18 +116,18 @@ func NewKeeper( keeper := Keeper{ storeKey: storeKey, cdc: cdc, - wasmer: wasmer, + wasmVM: wasmer, accountKeeper: accountKeeper, bank: NewBankCoinTransferrer(bankKeeper), ChannelKeeper: channelKeeper, portKeeper: portKeeper, capabilityKeeper: capabilityKeeper, - messenger: NewDefaultMessageHandler(router, channelKeeper, capabilityKeeper, cdc, customEncoders), + messenger: NewDefaultMessageHandler(router, channelKeeper, capabilityKeeper, cdc), queryGasLimit: wasmConfig.SmartQueryGasLimit, authZPolicy: DefaultAuthorizationPolicy{}, paramSpace: paramSpace, } - keeper.queryPlugins = DefaultQueryPlugins(bankKeeper, stakingKeeper, distKeeper, channelKeeper, queryRouter, &keeper).Merge(customPlugins) + keeper.wasmVMQueryHandler = DefaultQueryPlugins(bankKeeper, stakingKeeper, distKeeper, channelKeeper, queryRouter, &keeper) for _, o := range opts { o.apply(&keeper) } @@ -173,7 +178,7 @@ func (k Keeper) create(ctx sdk.Context, creator sdk.AccAddress, wasmCode []byte, } ctx.GasMeter().ConsumeGas(CompileCost*uint64(len(wasmCode)), "Compiling WASM Bytecode") - codeHash, err := k.wasmer.Create(wasmCode) + codeHash, err := k.wasmVM.Create(wasmCode) if err != nil { return 0, sdkerrors.Wrap(types.ErrCreateFailed, err.Error()) } @@ -198,7 +203,7 @@ func (k Keeper) importCode(ctx sdk.Context, codeID uint64, codeInfo types.CodeIn if err != nil { return sdkerrors.Wrap(types.ErrCreateFailed, err.Error()) } - newCodeHash, err := k.wasmer.Create(wasmCode) + newCodeHash, err := k.wasmVM.Create(wasmCode) if err != nil { return sdkerrors.Wrap(types.ErrCreateFailed, err.Error()) } @@ -269,11 +274,11 @@ func (k Keeper) instantiate(ctx sdk.Context, codeID uint64, creator, admin sdk.A prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) // prepare querier - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress) // instantiate wasm contract gas := gasForContract(ctx) - res, gasUsed, err := k.wasmer.Instantiate(codeInfo.CodeHash, env, info, initMsg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) + res, gasUsed, err := k.wasmVM.Instantiate(codeInfo.CodeHash, env, info, initMsg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) if err != nil { return contractAddress, nil, sdkerrors.Wrap(types.ErrInstantiateFailed, err.Error()) @@ -288,7 +293,7 @@ func (k Keeper) instantiate(ctx sdk.Context, codeID uint64, creator, admin sdk.A contractInfo := types.NewContractInfo(codeID, creator, admin, label, createdAt) // check for IBC flag - report, err := k.wasmer.AnalyzeCode(codeInfo.CodeHash) + report, err := k.wasmVM.AnalyzeCode(codeInfo.CodeHash) if err != nil { return contractAddress, nil, sdkerrors.Wrap(types.ErrInstantiateFailed, err.Error()) } @@ -336,9 +341,9 @@ func (k Keeper) Execute(ctx sdk.Context, contractAddress sdk.AccAddress, caller info := types.NewInfo(caller, coins) // prepare querier - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.Execute(codeInfo.CodeHash, env, info, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) + res, gasUsed, execErr := k.wasmVM.Execute(codeInfo.CodeHash, env, info, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) if execErr != nil { return nil, sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -383,7 +388,7 @@ func (k Keeper) migrate(ctx sdk.Context, contractAddress sdk.AccAddress, caller } // check for IBC flag - switch report, err := k.wasmer.AnalyzeCode(newCodeInfo.CodeHash); { + switch report, err := k.wasmVM.AnalyzeCode(newCodeInfo.CodeHash); { case err != nil: return nil, sdkerrors.Wrap(types.ErrMigrationFailed, err.Error()) case !report.HasIBCEntryPoints && contractInfo.IBCPortID != "": @@ -401,12 +406,12 @@ func (k Keeper) migrate(ctx sdk.Context, contractAddress sdk.AccAddress, caller env := types.NewEnv(ctx, contractAddress) // prepare querier - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress) prefixStoreKey := types.GetContractStorePrefix(contractAddress) prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) gas := gasForContract(ctx) - res, gasUsed, err := k.wasmer.Migrate(newCodeInfo.CodeHash, env, msg, &prefixStore, cosmwasmAPI, &querier, gasMeter(ctx), gas) + res, gasUsed, err := k.wasmVM.Migrate(newCodeInfo.CodeHash, env, msg, &prefixStore, cosmwasmAPI, &querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) if err != nil { return nil, sdkerrors.Wrap(types.ErrMigrationFailed, err.Error()) @@ -450,9 +455,9 @@ func (k Keeper) Sudo(ctx sdk.Context, contractAddress sdk.AccAddress, msg []byte env := types.NewEnv(ctx, contractAddress) // prepare querier - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.Sudo(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) + res, gasUsed, execErr := k.wasmVM.Sudo(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) if execErr != nil { return nil, sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -491,10 +496,10 @@ func (k Keeper) reply(ctx sdk.Context, contractAddress sdk.AccAddress, reply was // prepare querier querier := QueryHandler{ Ctx: ctx, - Plugins: k.queryPlugins, + Plugins: k.wasmVMQueryHandler, } gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.Reply(codeInfo.CodeHash, env, reply, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) + res, gasUsed, execErr := k.wasmVM.Reply(codeInfo.CodeHash, env, reply, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) if execErr != nil { return nil, sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -581,10 +586,10 @@ func (k Keeper) QuerySmart(ctx sdk.Context, contractAddr sdk.AccAddress, req []b } // prepare querier - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) env := types.NewEnv(ctx, contractAddr) - queryResult, gasUsed, qErr := k.wasmer.Query(codeInfo.CodeHash, env, req, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gasForContract(ctx)) + queryResult, gasUsed, qErr := k.wasmVM.Query(codeInfo.CodeHash, env, req, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gasForContract(ctx)) consumeGas(ctx, gasUsed) if qErr != nil { return nil, sdkerrors.Wrap(types.ErrQueryFailed, qErr.Error()) @@ -716,7 +721,7 @@ func (k Keeper) GetByteCode(ctx sdk.Context, codeID uint64) ([]byte, error) { return nil, nil } k.cdc.MustUnmarshalBinaryBare(codeInfoBz, &codeInfo) - return k.wasmer.GetCode(codeInfo.CodeHash) + return k.wasmVM.GetCode(codeInfo.CodeHash) } // PinCode pins the wasm contract in wasmvm cache @@ -726,7 +731,7 @@ func (k Keeper) PinCode(ctx sdk.Context, codeID uint64) error { return sdkerrors.Wrap(types.ErrNotFound, "code info") } - if err := k.wasmer.Pin(codeInfo.CodeHash); err != nil { + if err := k.wasmVM.Pin(codeInfo.CodeHash); err != nil { return sdkerrors.Wrap(types.ErrPinContractFailed, err.Error()) } store := ctx.KVStore(k.storeKey) @@ -741,7 +746,7 @@ func (k Keeper) UnpinCode(ctx sdk.Context, codeID uint64) error { if codeInfo == nil { return sdkerrors.Wrap(types.ErrNotFound, "code info") } - if err := k.wasmer.Unpin(codeInfo.CodeHash); err != nil { + if err := k.wasmVM.Unpin(codeInfo.CodeHash); err != nil { return sdkerrors.Wrap(types.ErrUnpinContractFailed, err.Error()) } @@ -765,7 +770,7 @@ func (k Keeper) InitializePinnedCodes(ctx sdk.Context) error { if codeInfo == nil { return sdkerrors.Wrap(types.ErrNotFound, "code info") } - if err := k.wasmer.Pin(codeInfo.CodeHash); err != nil { + if err := k.wasmVM.Pin(codeInfo.CodeHash); err != nil { return sdkerrors.Wrap(types.ErrPinContractFailed, err.Error()) } } diff --git a/x/wasm/internal/keeper/keeper_test.go b/x/wasm/internal/keeper/keeper_test.go index 855ed4105a..11433229d7 100644 --- a/x/wasm/internal/keeper/keeper_test.go +++ b/x/wasm/internal/keeper/keeper_test.go @@ -25,12 +25,12 @@ import ( const SupportedFeatures = "staking,stargate" func TestNewKeeper(t *testing.T) { - _, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + _, keepers := CreateTestInput(t, false, SupportedFeatures) require.NotNil(t, keepers.WasmKeeper) } func TestCreate(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -79,7 +79,7 @@ func TestCreateStoresInstantiatePermission(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper keeper.setParams(ctx, types.Params{ CodeUploadAccess: types.AllowEverybody, @@ -99,7 +99,7 @@ func TestCreateStoresInstantiatePermission(t *testing.T) { } func TestCreateWithParamPermissions(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, bankKeeper, keeper := keepers.AccountKeeper, keepers.BankKeeper, keepers.WasmKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -146,7 +146,7 @@ func TestCreateWithParamPermissions(t *testing.T) { } func TestCreateDuplicate(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -175,7 +175,7 @@ func TestCreateDuplicate(t *testing.T) { } func TestCreateWithSimulation(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper ctx = ctx.WithBlockHeader(tmproto.Header{Height: 1}). @@ -193,7 +193,7 @@ func TestCreateWithSimulation(t *testing.T) { require.Equal(t, uint64(1), contractID) // then try to create it in non-simulation mode (should not fail) - ctx, keepers = CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers = CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper = keepers.AccountKeeper, keepers.WasmKeeper contractID, err = keeper.Create(ctx, creator, wasmCode, "https://github.com/CosmWasm/wasmd/blob/master/x/wasm/testdata/escrow.wasm", "any/builder:tag", nil) require.NoError(t, err) @@ -231,7 +231,7 @@ func TestIsSimulationMode(t *testing.T) { } func TestCreateWithGzippedPayload(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -252,7 +252,7 @@ func TestCreateWithGzippedPayload(t *testing.T) { } func TestInstantiate(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -338,7 +338,7 @@ func TestInstantiateWithDeposit(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, bankKeeper, keeper := keepers.AccountKeeper, keepers.BankKeeper, keepers.WasmKeeper if spec.fundAddr { @@ -408,7 +408,7 @@ func TestInstantiateWithPermissions(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, bankKeeper, keeper := keepers.AccountKeeper, keepers.BankKeeper, keepers.WasmKeeper fundAccounts(t, ctx, accKeeper, bankKeeper, spec.srcActor, deposit) @@ -422,7 +422,7 @@ func TestInstantiateWithPermissions(t *testing.T) { } func TestInstantiateWithNonExistingCodeID(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -439,7 +439,7 @@ func TestInstantiateWithNonExistingCodeID(t *testing.T) { } func TestInstantiateWithContractDataResponse(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) wasmerMock := &wasmtesting.MockWasmer{ InstantiateFn: func(codeID wasmvm.Checksum, env wasmvmtypes.Env, info wasmvmtypes.MessageInfo, initMsg []byte, store wasmvm.KVStore, goapi wasmvm.GoAPI, querier wasmvm.Querier, gasMeter wasmvm.GasMeter, gasLimit uint64) (*wasmvmtypes.Response, uint64, error) { @@ -456,7 +456,7 @@ func TestInstantiateWithContractDataResponse(t *testing.T) { } func TestExecute(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -592,7 +592,7 @@ func TestExecuteWithDeposit(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, bankKeeper, keeper := keepers.AccountKeeper, keepers.BankKeeper, keepers.WasmKeeper if spec.newBankParams != nil { bankKeeper.SetParams(ctx, *spec.newBankParams) @@ -626,7 +626,7 @@ func TestExecuteWithDeposit(t *testing.T) { } func TestExecuteWithNonExistingAddress(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -639,7 +639,7 @@ func TestExecuteWithNonExistingAddress(t *testing.T) { } func TestExecuteWithPanic(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -673,7 +673,7 @@ func TestExecuteWithPanic(t *testing.T) { } func TestExecuteWithCpuLoop(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -718,7 +718,7 @@ func TestExecuteWithCpuLoop(t *testing.T) { } func TestExecuteWithStorageLoop(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -762,7 +762,7 @@ func TestExecuteWithStorageLoop(t *testing.T) { } func TestMigrate(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -943,7 +943,7 @@ func TestMigrate(t *testing.T) { } func TestMigrateReplacesTheSecondIndex(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) example := InstantiateHackatomExampleContract(t, ctx, keepers) // then assert a second index exists @@ -968,7 +968,7 @@ func TestMigrateReplacesTheSecondIndex(t *testing.T) { } func TestMigrateWithDispatchedMessage(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -1064,7 +1064,7 @@ type stealFundsMsg struct { } func TestSudo(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -1143,7 +1143,7 @@ func mustMarshal(t *testing.T, r interface{}) []byte { } func TestUpdateContractAdmin(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -1214,7 +1214,7 @@ func TestUpdateContractAdmin(t *testing.T) { } func TestClearContractAdmin(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) diff --git a/x/wasm/internal/keeper/legacy_querier_test.go b/x/wasm/internal/keeper/legacy_querier_test.go index 94276541eb..80379f1de5 100644 --- a/x/wasm/internal/keeper/legacy_querier_test.go +++ b/x/wasm/internal/keeper/legacy_querier_test.go @@ -16,7 +16,7 @@ import ( ) func TestLegacyQueryContractState(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -151,7 +151,7 @@ func TestLegacyQueryContractState(t *testing.T) { } func TestLegacyQueryContractListByCodeOrdering(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 1000000)) @@ -216,7 +216,7 @@ func TestLegacyQueryContractListByCodeOrdering(t *testing.T) { } func TestLegacyQueryContractHistory(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper var ( @@ -327,7 +327,7 @@ func TestLegacyQueryCodeList(t *testing.T) { for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper for _, codeID := range spec.codeIDs { diff --git a/x/wasm/internal/keeper/options.go b/x/wasm/internal/keeper/options.go index b26897ee02..cdb8590366 100644 --- a/x/wasm/internal/keeper/options.go +++ b/x/wasm/internal/keeper/options.go @@ -1,6 +1,9 @@ package keeper -import "github.com/CosmWasm/wasmd/x/wasm/internal/types" +import ( + "fmt" + "github.com/CosmWasm/wasmd/x/wasm/internal/types" +) type optsFn func(*Keeper) @@ -8,21 +11,63 @@ func (f optsFn) apply(keeper *Keeper) { f(keeper) } -// WithMessageHandler is an optional constructor parameter to replace the default wasm vm engine with the +// WithMessageHandler is an optional constructor parameter to replace the default wasmVM engine with the // given one. func WithWasmEngine(x types.WasmerEngine) Option { return optsFn(func(k *Keeper) { - k.wasmer = x + k.wasmVM = x }) } -// WithMessageHandler is an optional constructor parameter to set a custom message handler. +// WithMessageHandler is an optional constructor parameter to set a custom handler for wasmVM messages. +// This option should not be combined with Option `WithMessageEncoders`. func WithMessageHandler(x messenger) Option { return optsFn(func(k *Keeper) { k.messenger = x }) } +// WithQueryHandler is an optional constructor parameter to set custom query handler for wasmVM requests. +// This option should not be combined with Option `WithQueryPlugins`. +func WithQueryHandler(x wasmVMQueryHandler) Option { + return optsFn(func(k *Keeper) { + k.wasmVMQueryHandler = x + }) +} + +// WithQueryPlugins is an optional constructor parameter to pass custom query plugins for wasmVM requests. +// This option expects the default `QueryHandler` set an should not be combined with Option `WithQueryHandler`. +func WithQueryPlugins(x *QueryPlugins) Option { + return optsFn(func(k *Keeper) { + q, ok := k.wasmVMQueryHandler.(QueryPlugins) + if !ok { + panic(fmt.Sprintf("Unsupported query handler type: %T", k.wasmVMQueryHandler)) + } + k.wasmVMQueryHandler = q.Merge(x) + }) +} + +// WithMessageEncoders is an optional constructor parameter to pass custom message encoder to the default wasm message handler. +// This option expects the `DefaultMessageHandler` set an should not be combined with Option `WithMessageHandler`. +func WithMessageEncoders(x *MessageEncoders) Option { + return optsFn(func(k *Keeper) { + q, ok := k.messenger.(*MessageHandlerChain) + if !ok { + panic(fmt.Sprintf("Unsupported message handler type: %T", k.messenger)) + } + s, ok := q.handlers[0].(SDKMessageHandler) + if !ok { + panic(fmt.Sprintf("Unexpected message handler type: %T", q.handlers[0])) + } + e, ok := s.encoders.(MessageEncoders) + if !ok { + panic(fmt.Sprintf("Unsupported encoder type: %T", s.encoders)) + } + s.encoders = e.Merge(x) + q.handlers[0] = s + }) +} + // WithCoinTransferrer is an optional constructor parameter to set a custom coin transferrer func WithCoinTransferrer(x coinTransferrer) Option { return optsFn(func(k *Keeper) { diff --git a/x/wasm/internal/keeper/options_test.go b/x/wasm/internal/keeper/options_test.go index 74ee11373d..eac62378f0 100644 --- a/x/wasm/internal/keeper/options_test.go +++ b/x/wasm/internal/keeper/options_test.go @@ -19,7 +19,7 @@ func TestConstructorOptions(t *testing.T) { "wasm engine": { srcOpt: WithWasmEngine(&wasmtesting.MockWasmer{}), verify: func(k Keeper) { - assert.IsType(t, k.wasmer, &wasmtesting.MockWasmer{}) + assert.IsType(t, k.wasmVM, &wasmtesting.MockWasmer{}) }, }, "message handler": { @@ -28,6 +28,12 @@ func TestConstructorOptions(t *testing.T) { assert.IsType(t, k.messenger, &wasmtesting.MockMessageHandler{}) }, }, + "query plugins": { + srcOpt: WithQueryHandler(&wasmtesting.MockQueryHandler{}), + verify: func(k Keeper) { + assert.IsType(t, k.wasmVMQueryHandler, &wasmtesting.MockQueryHandler{}) + }, + }, "coin transferrer": { srcOpt: WithCoinTransferrer(&wasmtesting.MockCoinTransferrer{}), verify: func(k Keeper) { @@ -37,26 +43,7 @@ func TestConstructorOptions(t *testing.T) { } for name, spec := range specs { t.Run(name, func(t *testing.T) { - k := NewKeeper( - nil, - nil, - paramtypes.NewSubspace(nil, nil, nil, nil, ""), - authkeeper.AccountKeeper{}, - nil, - stakingkeeper.Keeper{}, - distributionkeeper.Keeper{}, - nil, - nil, - nil, - nil, - nil, - "tempDir", - types.DefaultWasmConfig(), - SupportedFeatures, - nil, - nil, - spec.srcOpt, - ) + k := NewKeeper(nil, nil, paramtypes.NewSubspace(nil, nil, nil, nil, ""), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, nil, nil, nil, nil, "tempDir", types.DefaultWasmConfig(), SupportedFeatures, spec.srcOpt) spec.verify(k) }) } diff --git a/x/wasm/internal/keeper/proposal_integration_test.go b/x/wasm/internal/keeper/proposal_integration_test.go index 0807feee41..bd7e3f508f 100644 --- a/x/wasm/internal/keeper/proposal_integration_test.go +++ b/x/wasm/internal/keeper/proposal_integration_test.go @@ -19,7 +19,7 @@ import ( ) func TestStoreCodeProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper wasmKeeper.setParams(ctx, types.Params{ CodeUploadAccess: types.AllowNobody, @@ -60,7 +60,7 @@ func TestStoreCodeProposal(t *testing.T) { } func TestInstantiateProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper wasmKeeper.setParams(ctx, types.Params{ CodeUploadAccess: types.AllowNobody, @@ -116,7 +116,7 @@ func TestInstantiateProposal(t *testing.T) { } func TestMigrateProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper wasmKeeper.setParams(ctx, types.Params{ CodeUploadAccess: types.AllowNobody, @@ -249,7 +249,7 @@ func TestAdminProposals(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper wasmKeeper.setParams(ctx, types.Params{ CodeUploadAccess: types.AllowNobody, @@ -279,7 +279,7 @@ func TestAdminProposals(t *testing.T) { } func TestUpdateParamsProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper var ( @@ -351,7 +351,7 @@ func TestUpdateParamsProposal(t *testing.T) { } func TestPinCodesProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper mock := wasmtesting.MockWasmer{ @@ -438,7 +438,7 @@ func TestPinCodesProposal(t *testing.T) { } } func TestUnpinCodesProposal(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, "staking", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking") govKeeper, wasmKeeper := keepers.GovKeeper, keepers.WasmKeeper mock := wasmtesting.MockWasmer{ diff --git a/x/wasm/internal/keeper/querier_test.go b/x/wasm/internal/keeper/querier_test.go index 5230c20803..42c8cd70a1 100644 --- a/x/wasm/internal/keeper/querier_test.go +++ b/x/wasm/internal/keeper/querier_test.go @@ -20,7 +20,7 @@ import ( ) func TestQueryAllContractState(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers) @@ -107,7 +107,7 @@ func TestQueryAllContractState(t *testing.T) { } func TestQuerySmartContractState(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers) @@ -150,7 +150,7 @@ func TestQuerySmartContractState(t *testing.T) { } func TestQuerySmartContractPanics(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) contractAddr := contractAddress(1, 1) keepers.WasmKeeper.storeCodeInfo(ctx, 1, types.CodeInfo{}) keepers.WasmKeeper.storeContractInfo(ctx, contractAddr, &types.ContractInfo{ @@ -178,7 +178,7 @@ func TestQuerySmartContractPanics(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - keepers.WasmKeeper.wasmer = &wasmtesting.MockWasmer{QueryFn: func(checksum cosmwasm.Checksum, env wasmvmtypes.Env, queryMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) ([]byte, uint64, error) { + keepers.WasmKeeper.wasmVM = &wasmtesting.MockWasmer{QueryFn: func(checksum cosmwasm.Checksum, env wasmvmtypes.Env, queryMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) ([]byte, uint64, error) { spec.doInContract() return nil, 0, nil }} @@ -194,7 +194,7 @@ func TestQuerySmartContractPanics(t *testing.T) { } func TestQueryRawContractState(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers) @@ -249,7 +249,7 @@ func TestQueryRawContractState(t *testing.T) { } func TestQueryContractListByCodeOrdering(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 1000000)) @@ -308,7 +308,7 @@ func TestQueryContractListByCodeOrdering(t *testing.T) { } func TestQueryContractHistory(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper var ( @@ -452,7 +452,7 @@ func TestQueryCodeList(t *testing.T) { wasmCode, err := ioutil.ReadFile("./testdata/hackatom.wasm") require.NoError(t, err) - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures) keeper := keepers.WasmKeeper specs := map[string]struct { diff --git a/x/wasm/internal/keeper/query_plugins.go b/x/wasm/internal/keeper/query_plugins.go index 6dddd5223d..c4f6a299b8 100644 --- a/x/wasm/internal/keeper/query_plugins.go +++ b/x/wasm/internal/keeper/query_plugins.go @@ -16,14 +16,14 @@ import ( type QueryHandler struct { Ctx sdk.Context - Plugins QueryPlugins + Plugins wasmVMQueryHandler Caller sdk.AccAddress } -func NewQueryHandler(ctx sdk.Context, plugins QueryPlugins, caller sdk.AccAddress) QueryHandler { +func NewQueryHandler(ctx sdk.Context, vmQueryHandler wasmVMQueryHandler, caller sdk.AccAddress) QueryHandler { return QueryHandler{ Ctx: ctx, - Plugins: plugins, + Plugins: vmQueryHandler, Caller: caller, } } @@ -51,27 +51,7 @@ func (q QueryHandler) Query(request wasmvmtypes.QueryRequest, gasLimit uint64) ( defer func() { q.Ctx.GasMeter().ConsumeGas(subctx.GasMeter().GasConsumed(), "contract sub-query") }() - - // do the query - if request.Bank != nil { - return q.Plugins.Bank(subctx, request.Bank) - } - if request.Custom != nil { - return q.Plugins.Custom(subctx, request.Custom) - } - if request.IBC != nil { - return q.Plugins.IBC(subctx, q.Caller, request.IBC) - } - if request.Staking != nil { - return q.Plugins.Staking(subctx, request.Staking) - } - if request.Stargate != nil { - return q.Plugins.Stargate(subctx, request.Stargate) - } - if request.Wasm != nil { - return q.Plugins.Wasm(subctx, request.Wasm) - } - return nil, wasmvmtypes.Unknown{} + return q.Plugins.HandleQuery(subctx, q.Caller, request) } func (q QueryHandler) GasConsumed() uint64 { @@ -89,7 +69,24 @@ type QueryPlugins struct { Wasm func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) } -func DefaultQueryPlugins(bank types.BankViewKeeper, staking types.StakingKeeper, distKeeper types.DistributionKeeper, channelKeeper types.ChannelKeeper, queryRouter GRPCQueryRouter, wasm *Keeper) QueryPlugins { +type contractMetaDataSource interface { + GetContractInfo(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo +} + +type wasmQueryKeeper interface { + contractMetaDataSource + QueryRaw(ctx sdk.Context, contractAddress sdk.AccAddress, key []byte) []byte + QuerySmart(ctx sdk.Context, contractAddr sdk.AccAddress, req []byte) ([]byte, error) +} + +func DefaultQueryPlugins( + bank types.BankViewKeeper, + staking types.StakingKeeper, + distKeeper types.DistributionKeeper, + channelKeeper types.ChannelKeeper, + queryRouter GRPCQueryRouter, + wasm wasmQueryKeeper, +) QueryPlugins { return QueryPlugins{ Bank: BankQuerier(bank), Custom: NoCustomQuerier, @@ -126,6 +123,30 @@ func (e QueryPlugins) Merge(o *QueryPlugins) QueryPlugins { return e } +// HandleQuery executes the requested query +func (e QueryPlugins) HandleQuery(ctx sdk.Context, caller sdk.AccAddress, request wasmvmtypes.QueryRequest) ([]byte, error) { + // do the query + if request.Bank != nil { + return e.Bank(ctx, request.Bank) + } + if request.Custom != nil { + return e.Custom(ctx, request.Custom) + } + if request.IBC != nil { + return e.IBC(ctx, caller, request.IBC) + } + if request.Staking != nil { + return e.Staking(ctx, request.Staking) + } + if request.Stargate != nil { + return e.Stargate(ctx, request.Stargate) + } + if request.Wasm != nil { + return e.Wasm(ctx, request.Wasm) + } + return nil, wasmvmtypes.Unknown{} +} + func BankQuerier(bankKeeper types.BankViewKeeper) func(ctx sdk.Context, request *wasmvmtypes.BankQuery) ([]byte, error) { return func(ctx sdk.Context, request *wasmvmtypes.BankQuery) ([]byte, error) { if request.AllBalances != nil { @@ -162,7 +183,7 @@ func NoCustomQuerier(sdk.Context, json.RawMessage) ([]byte, error) { return nil, wasmvmtypes.UnsupportedRequest{Kind: "custom"} } -func IBCQuerier(wasm *Keeper, channelKeeper types.ChannelKeeper) func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) { +func IBCQuerier(wasm contractMetaDataSource, channelKeeper types.ChannelKeeper) func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) { return func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) { if request.PortID != nil { contractInfo := wasm.GetContractInfo(ctx, caller) @@ -417,7 +438,7 @@ func getAccumulatedRewards(ctx sdk.Context, distKeeper types.DistributionKeeper, return rewards, nil } -func WasmQuerier(wasm *Keeper) func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) { +func WasmQuerier(wasm wasmQueryKeeper) func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) { return func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) { if request.Smart != nil { addr, err := sdk.AccAddressFromBech32(request.Smart.ContractAddr) diff --git a/x/wasm/internal/keeper/query_plugins_test.go b/x/wasm/internal/keeper/query_plugins_test.go new file mode 100644 index 0000000000..98cc6bc24d --- /dev/null +++ b/x/wasm/internal/keeper/query_plugins_test.go @@ -0,0 +1,255 @@ +package keeper + +import ( + "github.com/CosmWasm/wasmd/x/wasm/internal/keeper/wasmtesting" + "github.com/CosmWasm/wasmd/x/wasm/internal/types" + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestIBCQuerier(t *testing.T) { + myExampleChannels := []channeltypes.IdentifiedChannel{ + { + State: channeltypes.OPEN, + Ordering: channeltypes.ORDERED, + Counterparty: channeltypes.Counterparty{ + PortId: "counterPartyPortID", + ChannelId: "counterPartyChannelID", + }, + ConnectionHops: []string{"one"}, + Version: "v1", + PortId: "myPortID", + ChannelId: "myChannelID", + }, + { + State: channeltypes.INIT, + Ordering: channeltypes.UNORDERED, + Counterparty: channeltypes.Counterparty{ + PortId: "otherCounterPartyPortID", + ChannelId: "otherCounterPartyChannelID", + }, + ConnectionHops: []string{"other", "second"}, + Version: "otherVersion", + PortId: "otherPortID", + ChannelId: "otherChannelID", + }, + } + specs := map[string]struct { + srcQuery *wasmvmtypes.IBCQuery + wasmKeeper *wasmKeeperMock + channelKeeper *wasmtesting.MockChannelKeeper + expJsonResult string + expErr *sdkerrors.Error + }{ + "query port id": { + srcQuery: &wasmvmtypes.IBCQuery{ + PortID: &wasmvmtypes.PortIDQuery{}, + }, + wasmKeeper: newWasmKeeperMock( + func(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo { + return &types.ContractInfo{IBCPortID: "myIBCPortID"} + }, + ), + channelKeeper: &wasmtesting.MockChannelKeeper{}, + expJsonResult: `{"port_id":"myIBCPortID"}`, + }, + "query list channels - all": { + srcQuery: &wasmvmtypes.IBCQuery{ + ListChannels: &wasmvmtypes.ListChannelsQuery{}, + }, + channelKeeper: &wasmtesting.MockChannelKeeper{ + IterateChannelsFn: wasmtesting.MockChannelKeeperIterator(myExampleChannels), + }, + expJsonResult: `{ + "channels": [ + { + "endpoint": { + "port_id": "myPortID", + "channel_id": "myChannelID" + }, + "counterparty_endpoint": { + "port_id": "counterPartyPortID", + "channel_id": "counterPartyChannelID" + }, + "order": "ORDER_ORDERED", + "version": "v1", + "connection_id": "one" + }, + { + "endpoint": { + "port_id": "otherPortID", + "channel_id": "otherChannelID" + }, + "counterparty_endpoint": { + "port_id": "otherCounterPartyPortID", + "channel_id": "otherCounterPartyChannelID" + }, + "order": "ORDER_UNORDERED", + "version": "otherVersion", + "connection_id": "other" + } + ] +}`, + }, + "query list channels - filtered": { + srcQuery: &wasmvmtypes.IBCQuery{ + ListChannels: &wasmvmtypes.ListChannelsQuery{ + PortID: "otherPortID", + }, + }, + channelKeeper: &wasmtesting.MockChannelKeeper{ + IterateChannelsFn: wasmtesting.MockChannelKeeperIterator(myExampleChannels), + }, + expJsonResult: `{ + "channels": [ + { + "endpoint": { + "port_id": "otherPortID", + "channel_id": "otherChannelID" + }, + "counterparty_endpoint": { + "port_id": "otherCounterPartyPortID", + "channel_id": "otherCounterPartyChannelID" + }, + "order": "ORDER_UNORDERED", + "version": "otherVersion", + "connection_id": "other" + } + ] +}`, + }, + "query list channels - filtered empty": { + srcQuery: &wasmvmtypes.IBCQuery{ + ListChannels: &wasmvmtypes.ListChannelsQuery{ + PortID: "none-existing", + }, + }, + channelKeeper: &wasmtesting.MockChannelKeeper{ + IterateChannelsFn: wasmtesting.MockChannelKeeperIterator(myExampleChannels), + }, + expJsonResult: `{"channels": []}`, + }, + "query channel": { + srcQuery: &wasmvmtypes.IBCQuery{ + Channel: &wasmvmtypes.ChannelQuery{ + PortID: "myQueryPortID", + ChannelID: "myQueryChannelID", + }, + }, + channelKeeper: &wasmtesting.MockChannelKeeper{ + GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) { + return channeltypes.Channel{ + State: channeltypes.INIT, + Ordering: channeltypes.UNORDERED, + Counterparty: channeltypes.Counterparty{ + PortId: "counterPartyPortID", + ChannelId: "otherCounterPartyChannelID", + }, + ConnectionHops: []string{"one"}, + Version: "version", + }, true + }, + }, + expJsonResult: `{ + "channel": { + "endpoint": { + "port_id": "myQueryPortID", + "channel_id": "myQueryChannelID" + }, + "counterparty_endpoint": { + "port_id": "counterPartyPortID", + "channel_id": "otherCounterPartyChannelID" + }, + "order": "ORDER_UNORDERED", + "version": "version", + "connection_id": "one" + } +}`, + }, + "query channel - without port set": { + srcQuery: &wasmvmtypes.IBCQuery{ + Channel: &wasmvmtypes.ChannelQuery{ + ChannelID: "myQueryChannelID", + }, + }, + wasmKeeper: newWasmKeeperMock(func(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo { + return &types.ContractInfo{IBCPortID: "myLoadedPortID"} + }), + channelKeeper: &wasmtesting.MockChannelKeeper{ + GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) { + return channeltypes.Channel{ + State: channeltypes.INIT, + Ordering: channeltypes.UNORDERED, + Counterparty: channeltypes.Counterparty{ + PortId: "counterPartyPortID", + ChannelId: "otherCounterPartyChannelID", + }, + ConnectionHops: []string{"one"}, + Version: "version", + }, true + }, + }, + expJsonResult: `{ + "channel": { + "endpoint": { + "port_id": "myLoadedPortID", + "channel_id": "myQueryChannelID" + }, + "counterparty_endpoint": { + "port_id": "counterPartyPortID", + "channel_id": "otherCounterPartyChannelID" + }, + "order": "ORDER_UNORDERED", + "version": "version", + "connection_id": "one" + } +}`, + }, + "query channel - empty result": { + srcQuery: &wasmvmtypes.IBCQuery{ + Channel: &wasmvmtypes.ChannelQuery{ + PortID: "myQueryPortID", + ChannelID: "myQueryChannelID", + }, + }, + channelKeeper: &wasmtesting.MockChannelKeeper{ + GetChannelFn: func(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) { + return channeltypes.Channel{}, false + }, + }, + expJsonResult: "{}", + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + h := IBCQuerier(spec.wasmKeeper, spec.channelKeeper) + gotResult, gotErr := h(sdk.Context{}, RandomAccountAddress(t), spec.srcQuery) + require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr) + if spec.expErr != nil { + return + } + assert.JSONEq(t, spec.expJsonResult, string(gotResult), string(gotResult)) + }) + } + +} + +type wasmKeeperMock struct { + GetContractInfoFn func(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo +} + +func newWasmKeeperMock(f func(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo) *wasmKeeperMock { + return &wasmKeeperMock{GetContractInfoFn: f} +} + +func (m wasmKeeperMock) GetContractInfo(ctx sdk.Context, contractAddress sdk.AccAddress) *types.ContractInfo { + if m.GetContractInfoFn == nil { + panic("not expected to be called") + } + return m.GetContractInfoFn(ctx, contractAddress) +} diff --git a/x/wasm/internal/keeper/recurse_test.go b/x/wasm/internal/keeper/recurse_test.go index 01fd965568..bedb9aa0d1 100644 --- a/x/wasm/internal/keeper/recurse_test.go +++ b/x/wasm/internal/keeper/recurse_test.go @@ -2,6 +2,7 @@ package keeper import ( "encoding/json" + "github.com/CosmWasm/wasmd/x/wasm/internal/keeper/wasmtesting" "github.com/CosmWasm/wasmd/x/wasm/internal/types" "testing" @@ -40,14 +41,13 @@ var totalWasmQueryCounter int func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper *Keeper) { // we do one basic setup before all test cases (which are read-only and don't change state) var realWasmQuerier func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) - countingQuerier := &QueryPlugins{ - Wasm: func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) { + countingQuerier := &wasmtesting.MockQueryHandler{ + HandleQueryFn: func(ctx sdk.Context, request wasmvmtypes.QueryRequest, caller sdk.AccAddress) ([]byte, error) { totalWasmQueryCounter++ - return realWasmQuerier(ctx, request) - }, - } + return realWasmQuerier(ctx, request.Wasm) + }} - ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, countingQuerier) + ctx, keepers := CreateTestInput(t, false, SupportedFeatures, WithQueryHandler(countingQuerier)) keeper = keepers.WasmKeeper realWasmQuerier = WasmQuerier(keeper) diff --git a/x/wasm/internal/keeper/reflect_test.go b/x/wasm/internal/keeper/reflect_test.go index 1cf359d061..cd1272205b 100644 --- a/x/wasm/internal/keeper/reflect_test.go +++ b/x/wasm/internal/keeper/reflect_test.go @@ -83,8 +83,8 @@ func mustParse(t *testing.T, data []byte, res interface{}) { const ReflectFeatures = "staking,mask,stargate" func TestReflectContractSend(t *testing.T) { - cdc := MakeTestCodec(t) - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(cdc), nil) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc))) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -165,8 +165,8 @@ func TestReflectContractSend(t *testing.T) { } func TestReflectCustomMsg(t *testing.T) { - cdc := MakeTestCodec(t) - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(cdc), reflectPlugins()) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc)), WithQueryPlugins(reflectPlugins())) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -258,8 +258,8 @@ func TestReflectCustomMsg(t *testing.T) { } func TestMaskReflectCustomQuery(t *testing.T) { - cdc := MakeTestCodec(t) - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(cdc), reflectPlugins()) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc)), WithQueryPlugins(reflectPlugins())) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -308,8 +308,8 @@ func TestMaskReflectCustomQuery(t *testing.T) { } func TestReflectStargateQuery(t *testing.T) { - cdc := MakeTestCodec(t) - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(cdc), reflectPlugins()) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc)), WithQueryPlugins(reflectPlugins())) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper funds := sdk.NewCoins(sdk.NewInt64Coin("denom", 320000)) @@ -385,7 +385,8 @@ type reflectState struct { } func TestMaskReflectWasmQueries(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(MakeTestCodec(t)), nil) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc)), WithQueryPlugins(reflectPlugins())) accKeeper, keeper := keepers.AccountKeeper, keepers.WasmKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -456,7 +457,8 @@ func TestMaskReflectWasmQueries(t *testing.T) { } func TestWasmRawQueryWithNil(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, reflectEncoders(MakeTestCodec(t)), nil) + cdc := MakeEncodingConfig(t).Marshaler + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageEncoders(reflectEncoders(cdc)), WithQueryPlugins(reflectPlugins())) accKeeper, keeper := keepers.AccountKeeper, keepers.WasmKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) diff --git a/x/wasm/internal/keeper/relay.go b/x/wasm/internal/keeper/relay.go index 1aaddc6da5..0bdcd0d8c3 100644 --- a/x/wasm/internal/keeper/relay.go +++ b/x/wasm/internal/keeper/relay.go @@ -23,10 +23,10 @@ func (k Keeper) OnOpenChannel( } env := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - gasUsed, execErr := k.wasmer.IBCChannelOpen(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + gasUsed, execErr := k.wasmVM.IBCChannelOpen(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -53,10 +53,10 @@ func (k Keeper) OnConnectChannel( } env := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.IBCChannelConnect(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + res, gasUsed, execErr := k.wasmVM.IBCChannelConnect(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -89,10 +89,10 @@ func (k Keeper) OnCloseChannel( } params := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.IBCChannelClose(codeInfo.CodeHash, params, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + res, gasUsed, execErr := k.wasmVM.IBCChannelClose(codeInfo.CodeHash, params, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -125,10 +125,10 @@ func (k Keeper) OnRecvPacket( } env := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.IBCPacketReceive(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + res, gasUsed, execErr := k.wasmVM.IBCPacketReceive(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return nil, sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -162,10 +162,10 @@ func (k Keeper) OnAckPacket( } env := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.IBCPacketAck(codeInfo.CodeHash, env, acknowledgement, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + res, gasUsed, execErr := k.wasmVM.IBCPacketAck(codeInfo.CodeHash, env, acknowledgement, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) @@ -195,10 +195,10 @@ func (k Keeper) OnTimeoutPacket( } env := types.NewEnv(ctx, contractAddr) - querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) + querier := NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddr) gas := gasForContract(ctx) - res, gasUsed, execErr := k.wasmer.IBCPacketTimeout(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) + res, gasUsed, execErr := k.wasmVM.IBCPacketTimeout(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) consumeGas(ctx, gasUsed) if execErr != nil { return sdkerrors.Wrap(types.ErrExecuteFailed, execErr.Error()) diff --git a/x/wasm/internal/keeper/relay_test.go b/x/wasm/internal/keeper/relay_test.go index cb92eb680d..bcea274ccf 100644 --- a/x/wasm/internal/keeper/relay_test.go +++ b/x/wasm/internal/keeper/relay_test.go @@ -15,7 +15,7 @@ import ( func TestOnOpenChannel(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { @@ -74,7 +74,7 @@ func TestOnOpenChannel(t *testing.T) { func TestOnConnectChannel(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { @@ -184,7 +184,7 @@ func TestOnConnectChannel(t *testing.T) { func TestOnCloseChannel(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { @@ -294,7 +294,7 @@ func TestOnCloseChannel(t *testing.T) { func TestOnRecvPacket(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { @@ -419,7 +419,7 @@ func TestOnRecvPacket(t *testing.T) { func TestOnAckPacket(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { @@ -530,7 +530,7 @@ func TestOnAckPacket(t *testing.T) { func TestOnTimeoutPacket(t *testing.T) { var m wasmtesting.MockWasmer wasmtesting.MakeIBCInstantiable(&m) - parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil) + parentCtx, keepers := CreateTestInput(t, false, SupportedFeatures) example := SeedNewContractInstance(t, parentCtx, keepers, &m) specs := map[string]struct { diff --git a/x/wasm/internal/keeper/staking_test.go b/x/wasm/internal/keeper/staking_test.go index 8fc3b64b88..4cd2006310 100644 --- a/x/wasm/internal/keeper/staking_test.go +++ b/x/wasm/internal/keeper/staking_test.go @@ -90,7 +90,7 @@ type InvestmentResponse struct { } func TestInitializeStaking(t *testing.T) { - ctx, k := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, k := CreateTestInput(t, false, SupportedFeatures) accKeeper, stakingKeeper, keeper, bankKeeper := k.AccountKeeper, k.StakingKeeper, k.WasmKeeper, k.BankKeeper valAddr := addValidator(t, ctx, stakingKeeper, accKeeper, bankKeeper, sdk.NewInt64Coin("stake", 1234567)) @@ -163,7 +163,7 @@ type initInfo struct { } func initializeStaking(t *testing.T) initInfo { - ctx, k := CreateTestInput(t, false, SupportedFeatures, nil, nil) + ctx, k := CreateTestInput(t, false, SupportedFeatures) accKeeper, stakingKeeper, keeper, bankKeeper := k.AccountKeeper, k.StakingKeeper, k.WasmKeeper, k.BankKeeper valAddr := addValidator(t, ctx, stakingKeeper, accKeeper, bankKeeper, sdk.NewInt64Coin("stake", 1000000)) diff --git a/x/wasm/internal/keeper/submsg_test.go b/x/wasm/internal/keeper/submsg_test.go index c287e5cf62..086ff45353 100644 --- a/x/wasm/internal/keeper/submsg_test.go +++ b/x/wasm/internal/keeper/submsg_test.go @@ -17,7 +17,7 @@ import ( // Try a simple send, no gas limit to for a sanity check before trying table tests func TestDispatchSubMsgSuccessCase(t *testing.T) { - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, ReflectFeatures) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) @@ -125,7 +125,7 @@ func TestDispatchSubMsgErrorHandling(t *testing.T) { subGasLimit := uint64(300_000) // prep - create one chain and upload the code - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, nil, nil) + ctx, keepers := CreateTestInput(t, false, ReflectFeatures) ctx = ctx.WithGasMeter(sdk.NewInfiniteGasMeter()) ctx = ctx.WithBlockGasMeter(sdk.NewInfiniteGasMeter()) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper @@ -386,7 +386,7 @@ func TestDispatchSubMsgEncodeToNoSdkMsg(t *testing.T) { Bank: nilEncoder, } - ctx, keepers := CreateTestInput(t, false, ReflectFeatures, customEncoders, nil) + ctx, keepers := CreateTestInput(t, false, ReflectFeatures, WithMessageHandler(NewSDKMessageHandler(nil, customEncoders))) accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) diff --git a/x/wasm/internal/keeper/test_common.go b/x/wasm/internal/keeper/test_common.go index c0f42747db..601c3837dc 100644 --- a/x/wasm/internal/keeper/test_common.go +++ b/x/wasm/internal/keeper/test_common.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "encoding/json" "fmt" - "github.com/CosmWasm/wasmd/x/wasm/internal/types" + types "github.com/CosmWasm/wasmd/x/wasm/internal/types" "github.com/cosmos/cosmos-sdk/baseapp" "github.com/cosmos/cosmos-sdk/codec" codectypes "github.com/cosmos/cosmos-sdk/codec/types" @@ -133,19 +133,18 @@ type TestKeepers struct { GovKeeper govkeeper.Keeper WasmKeeper *Keeper IBCKeeper *ibckeeper.Keeper + Router *baseapp.Router } // CreateDefaultTestInput common settings for CreateTestInput func CreateDefaultTestInput(t TestingT) (sdk.Context, TestKeepers) { - return CreateTestInput(t, false, "staking", nil, nil) + return CreateTestInput(t, false, "staking") } // encoders can be nil to accept the defaults, or set it to override some of the message handlers (like default) -func CreateTestInput(t TestingT, isCheckTx bool, supportedFeatures string, encoders *MessageEncoders, queriers *QueryPlugins) (sdk.Context, TestKeepers) { +func CreateTestInput(t TestingT, isCheckTx bool, supportedFeatures string, opts ...Option) (sdk.Context, TestKeepers) { // Load default wasm config - wasmConfig := types.DefaultWasmConfig() - db := dbm.NewMemDB() - return createTestInput(t, isCheckTx, supportedFeatures, encoders, queriers, wasmConfig, db) + return createTestInput(t, isCheckTx, supportedFeatures, types.DefaultWasmConfig(), dbm.NewMemDB(), opts...) } // encoders can be nil to accept the defaults, or set it to override some of the message handlers (like default) @@ -153,10 +152,9 @@ func createTestInput( t TestingT, isCheckTx bool, supportedFeatures string, - encoders *MessageEncoders, - queriers *QueryPlugins, wasmConfig types.WasmConfig, db dbm.DB, + opts ...Option, ) (sdk.Context, TestKeepers) { tempDir := t.TempDir() @@ -299,8 +297,7 @@ func createTestInput( tempDir, wasmConfig, supportedFeatures, - encoders, - queriers, + opts..., ) keeper.setParams(ctx, types.DefaultParams()) // add wasm handler so we can loop-back (contracts calling contracts) @@ -329,6 +326,7 @@ func createTestInput( BankKeeper: bankKeeper, GovKeeper: govKeeper, IBCKeeper: ibcKeeper, + Router: router, } return ctx, keepers } @@ -484,7 +482,7 @@ func StoreRandomContract(t TestingT, ctx sdk.Context, keepers TestKeepers, mock anyAmount := sdk.NewCoins(sdk.NewInt64Coin("denom", 1000)) creator, _, creatorAddr := keyPubAddr() fundAccounts(t, ctx, keepers.AccountKeeper, keepers.BankKeeper, creatorAddr, anyAmount) - keepers.WasmKeeper.wasmer = mock + keepers.WasmKeeper.wasmVM = mock wasmCode := append(wasmIdent, rand.Bytes(10)...) codeID, err := keepers.WasmKeeper.Create(ctx, creatorAddr, wasmCode, "", "", nil) require.NoError(t, err) diff --git a/x/wasm/internal/keeper/wasmtesting/mock_keepers.go b/x/wasm/internal/keeper/wasmtesting/mock_keepers.go index acee3a4082..7d216911bc 100644 --- a/x/wasm/internal/keeper/wasmtesting/mock_keepers.go +++ b/x/wasm/internal/keeper/wasmtesting/mock_keepers.go @@ -13,6 +13,7 @@ type MockChannelKeeper struct { SendPacketFn func(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error ChanCloseInitFn func(ctx sdk.Context, portID, channelID string, chanCap *capabilitytypes.Capability) error GetAllChannelsFn func(ctx sdk.Context) []channeltypes.IdentifiedChannel + IterateChannelsFn func(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) } func (m *MockChannelKeeper) GetChannel(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) { @@ -29,17 +30,6 @@ func (m *MockChannelKeeper) GetAllChannels(ctx sdk.Context) []channeltypes.Ident return m.GetAllChannelsFn(ctx) } -// Auto-implemented from GetAllChannels data -func (m *MockChannelKeeper) IterateChannels(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) { - channels := m.GetAllChannels(ctx) - for _, channel := range channels { - stop := cb(channel) - if stop { - break - } - } -} - func (m *MockChannelKeeper) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (uint64, bool) { if m.GetNextSequenceSendFn == nil { panic("not supposed to be called!") @@ -61,6 +51,24 @@ func (m *MockChannelKeeper) ChanCloseInit(ctx sdk.Context, portID, channelID str return m.ChanCloseInitFn(ctx, portID, channelID, chanCap) } +func (m *MockChannelKeeper) IterateChannels(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) { + if m.IterateChannelsFn == nil { + panic("not expected to be called") + } + m.IterateChannelsFn(ctx, cb) +} + +func MockChannelKeeperIterator(s []channeltypes.IdentifiedChannel) func(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) { + return func(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) { + for _, channel := range s { + stop := cb(channel) + if stop { + break + } + } + } +} + type MockCapabilityKeeper struct { GetCapabilityFn func(ctx sdk.Context, name string) (*capabilitytypes.Capability, bool) ClaimCapabilityFn func(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error diff --git a/x/wasm/internal/keeper/wasmtesting/query_handler.go b/x/wasm/internal/keeper/wasmtesting/query_handler.go new file mode 100644 index 0000000000..52cf97d31f --- /dev/null +++ b/x/wasm/internal/keeper/wasmtesting/query_handler.go @@ -0,0 +1,17 @@ +package wasmtesting + +import ( + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type MockQueryHandler struct { + HandleQueryFn func(ctx sdk.Context, request wasmvmtypes.QueryRequest, caller sdk.AccAddress) ([]byte, error) +} + +func (m *MockQueryHandler) HandleQuery(ctx sdk.Context, caller sdk.AccAddress, request wasmvmtypes.QueryRequest) ([]byte, error) { + if m.HandleQueryFn == nil { + panic("not expected to be called") + } + return m.HandleQueryFn(ctx, request, caller) +} diff --git a/x/wasm/internal/types/errors.go b/x/wasm/internal/types/errors.go index 5e8799913f..200d7934f0 100644 --- a/x/wasm/internal/types/errors.go +++ b/x/wasm/internal/types/errors.go @@ -63,4 +63,7 @@ var ( // ErrUnpinContractFailed error for unpinning contract failures ErrUnpinContractFailed = sdkErrors.Register(DefaultCodespace, 19, "unpinning contract failed") + + // ErrUnknownMsg error by a message handler to show that it is not responsible for this message type + ErrUnknownMsg = sdkErrors.Register(DefaultCodespace, 20, "unknown message from the contract") ) diff --git a/x/wasm/module_test.go b/x/wasm/module_test.go index 7f0b77b5e2..a239151842 100644 --- a/x/wasm/module_test.go +++ b/x/wasm/module_test.go @@ -33,7 +33,7 @@ type testData struct { // returns a cleanup function, which must be defered on func setupTest(t *testing.T) testData { - ctx, keepers := CreateTestInput(t, false, "staking,stargate", nil, nil) + ctx, keepers := CreateTestInput(t, false, "staking,stargate") cdc := keeper.MakeTestCodec(t) data := testData{ module: NewAppModule(cdc, keepers.WasmKeeper, keepers.StakingKeeper),