Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client/tx): simulate with correct pk #18472

Merged
merged 10 commits into from
Nov 17, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Ref: https://keepachangelog.com/en/1.0.0/

### Bug Fixes

* (client/tx) [#18472](https://github.com/cosmos/cosmos-sdk/pull/18472) Utilizes the correct Pubkey when simulating a transaction.
* (baseapp) [#18486](https://github.com/cosmos/cosmos-sdk/pull/18486) Fixed FinalizeBlock calls not being passed to ABCIListeners
* (baseapp) [#18383](https://github.com/cosmos/cosmos-sdk/pull/18383) Fixed a data race inside BaseApp.getContext, found by end-to-end (e2e) tests.
* (client/server) [#18345](https://github.com/cosmos/cosmos-sdk/pull/18345) Consistently set viper prefix in client and server. It defaults for the binary name for both client and server.
Expand Down
7 changes: 3 additions & 4 deletions client/tx/aux_builder_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package tx_test
package tx

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
Expand All @@ -26,7 +25,7 @@ func TestAuxTxBuilder(t *testing.T) {
// required for test case: "GetAuxSignerData works for DIRECT_AUX"
counterModule.RegisterInterfaces(reg)

var b tx.AuxTxBuilder
var b AuxTxBuilder

testcases := []struct {
name string
Expand Down Expand Up @@ -200,7 +199,7 @@ func TestAuxTxBuilder(t *testing.T) {
for _, tc := range testcases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
b = tx.NewAuxTxBuilder()
b = NewAuxTxBuilder()
err := tc.malleate()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was there a reason to change the package name?

Copy link
Member

@julienrbrt julienrbrt Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's for testing their private function getSimSignatureData and set fromName on the Factory that is usually private. For the latter, they could use clientCtx with a from name however.


if tc.expErr {
Expand Down
40 changes: 29 additions & 11 deletions client/tx/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/cosmos/cosmos-sdk/client/flags"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -33,6 +34,7 @@ type Factory struct {
timeoutHeight uint64
gasAdjustment float64
chainID string
fromName string
offline bool
generateOnly bool
memo string
Expand Down Expand Up @@ -92,6 +94,7 @@ func NewFactoryCLI(clientCtx client.Context, flagSet *pflag.FlagSet) (Factory, e
accountRetriever: clientCtx.AccountRetriever,
keybase: clientCtx.Keyring,
chainID: clientCtx.ChainID,
fromName: clientCtx.FromName,
offline: clientCtx.Offline,
generateOnly: clientCtx.GenerateOnly,
gas: gasSetting.Gas,
Expand Down Expand Up @@ -406,10 +409,8 @@ func (f Factory) BuildSimTx(msgs ...sdk.Msg) ([]byte, error) {
// Create an empty signature literal as the ante handler will populate with a
// sentinel pubkey.
sig := signing.SignatureV2{
PubKey: pk,
Data: &signing.SingleSignatureData{
SignMode: f.signMode,
},
PubKey: pk,
Data: f.getSimSignatureData(pk),
Sequence: f.Sequence(),
}
if err := txb.SetSignatures(sig); err != nil {
Expand All @@ -435,16 +436,13 @@ func (f Factory) getSimPK() (cryptotypes.PubKey, error) {
pk cryptotypes.PubKey = &secp256k1.PubKey{} // use default public key type
)

// Use the first element from the list of keys in order to generate a valid
// pubkey that supports multiple algorithms.
if f.simulateAndExecute && f.keybase != nil {
records, _ := f.keybase.List()
if len(records) == 0 {
return nil, errors.New("cannot build signature for simulation, key records slice is empty")
record, err := f.keybase.Key(f.fromName)
if err != nil {
return nil, err
}

// take the first record just for simulation purposes
pk, ok = records[0].PubKey.GetCachedValue().(cryptotypes.PubKey)
pk, ok = record.PubKey.GetCachedValue().(cryptotypes.PubKey)
if !ok {
return nil, errors.New("cannot build signature for simulation, failed to convert proto Any to public key")
}
Expand All @@ -453,6 +451,26 @@ func (f Factory) getSimPK() (cryptotypes.PubKey, error) {
return pk, nil
}

// getSimSignatureData based on the pubKey type gets the correct SignatureData type
// to use for building a simulation tx.
func (f Factory) getSimSignatureData(pk cryptotypes.PubKey) signing.SignatureData {
multisigPubKey, ok := pk.(*multisig.LegacyAminoPubKey)
if !ok {
return &signing.SingleSignatureData{SignMode: f.signMode}
}

multiSignatureData := make([]signing.SignatureData, 0, multisigPubKey.Threshold)
for i := uint32(0); i < multisigPubKey.Threshold; i++ {
multiSignatureData = append(multiSignatureData, &signing.SingleSignatureData{
SignMode: f.SignMode(),
})
}

return &signing.MultiSignatureData{
Signatures: multiSignatureData,
}
}

// Prepare ensures the account defined by ctx.GetFromAddress() exists and
// if the account number and/or the account sequence number are zero (not set),
// they will be queried for and set on the provided Factory.
Expand Down
96 changes: 90 additions & 6 deletions client/tx/factory_test.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,119 @@
package tx_test
package tx

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
"github.com/cosmos/cosmos-sdk/crypto/hd"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
"github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
)

func TestFactoryPrepate(t *testing.T) {
func TestFactoryPrepare(t *testing.T) {
t.Parallel()

factory := tx.Factory{}
factory := Factory{}
clientCtx := client.Context{}

output, err := factory.Prepare(clientCtx.WithOffline(true))
require.NoError(t, err)
require.Equal(t, output, factory)

factory = tx.Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1}).WithAccountNumber(5)
factory = Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1}).WithAccountNumber(5)
output, err = factory.Prepare(clientCtx.WithFrom("foo"))
require.NoError(t, err)
require.NotEqual(t, output, factory)
require.Equal(t, output.AccountNumber(), uint64(5))
require.Equal(t, output.Sequence(), uint64(1))

factory = tx.Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1})
factory = Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1})
output, err = factory.Prepare(clientCtx.WithFrom("foo"))
require.NoError(t, err)
require.NotEqual(t, output, factory)
require.Equal(t, output.AccountNumber(), uint64(10))
require.Equal(t, output.Sequence(), uint64(1))
}

func TestFactory_getSimPKType(t *testing.T) {
// setup keyring
registry := codectypes.NewInterfaceRegistry()
cryptocodec.RegisterInterfaces(registry)
k := keyring.NewInMemory(codec.NewProtoCodec(registry))

tests := []struct {
name string
fromName string
genKey func(fromName string, k keyring.Keyring) error
wantType types.PubKey
}{
{
name: "simple key",
fromName: "testKey",
genKey: func(fromName string, k keyring.Keyring) error {
_, err := k.NewAccount(fromName, testdata.TestMnemonic, "", "", hd.Secp256k1)
return err
},
wantType: (*secp256k1.PubKey)(nil),
},
{
name: "multisig key",
fromName: "multiKey",
genKey: func(fromName string, k keyring.Keyring) error {
pk := multisig.NewLegacyAminoPubKey(1, []types.PubKey{&multisig.LegacyAminoPubKey{}})
_, err := k.SaveMultisig(fromName, pk)
return err
},
wantType: (*multisig.LegacyAminoPubKey)(nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
JulianToledano marked this conversation as resolved.
Show resolved Hide resolved
err := tt.genKey(tt.fromName, k)
require.NoError(t, err)
f := Factory{
keybase: k,
fromName: tt.fromName,
simulateAndExecute: true,
}
got, err := f.getSimPK()
require.NoError(t, err)
require.IsType(t, tt.wantType, got)
})
}
}

func TestFactory_getSimSignatureData(t *testing.T) {
tests := []struct {
name string
pk types.PubKey
wantType any
}{
{
name: "simple pubkey",
pk: &secp256k1.PubKey{},
wantType: (*signing.SingleSignatureData)(nil),
},
{
name: "multisig pubkey",
pk: &multisig.LegacyAminoPubKey{},
wantType: (*signing.MultiSignatureData)(nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Factory{}.getSimSignatureData(tt.pk)
require.IsType(t, tt.wantType, got)
})
}
}
2 changes: 1 addition & 1 deletion client/tx/legacy_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tx_test
package tx

import (
"github.com/cosmos/cosmos-sdk/testutil/testdata"
Expand Down
21 changes: 10 additions & 11 deletions client/tx/tx_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tx_test
package tx

import (
"context"
Expand All @@ -9,12 +9,11 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc"

ante "cosmossdk.io/x/auth/ante"
"cosmossdk.io/x/auth/ante"
"cosmossdk.io/x/auth/signing"
authtx "cosmossdk.io/x/auth/tx"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/crypto/hd"
Expand Down Expand Up @@ -81,7 +80,7 @@ func TestCalculateGas(t *testing.T) {
defaultSignMode, err := signing.APISignModeToInternal(txCfg.SignModeHandler().DefaultMode())
require.NoError(t, err)

txf := tx.Factory{}.
txf := Factory{}.
WithChainID("test-chain").
WithTxConfig(txCfg).WithSignMode(defaultSignMode)

Expand All @@ -90,7 +89,7 @@ func TestCalculateGas(t *testing.T) {
gasUsed: tc.args.mockGasUsed,
wantErr: tc.args.mockWantErr,
}
simRes, gotAdjusted, err := tx.CalculateGas(mockClientCtx, txf.WithGasAdjustment(stc.args.adjustment))
simRes, gotAdjusted, err := CalculateGas(mockClientCtx, txf.WithGasAdjustment(stc.args.adjustment))
if stc.expPass {
require.NoError(t, err)
require.Equal(t, simRes.GasInfo.GasUsed, stc.wantEstimate)
Expand All @@ -104,8 +103,8 @@ func TestCalculateGas(t *testing.T) {
}
}

func mockTxFactory(txCfg client.TxConfig) tx.Factory {
return tx.Factory{}.
func mockTxFactory(txCfg client.TxConfig) Factory {
return Factory{}.
WithTxConfig(txCfg).
WithAccountNumber(50).
WithSequence(23).
Expand Down Expand Up @@ -196,7 +195,7 @@ func TestMnemonicInMemo(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
txf := tx.Factory{}.
txf := Factory{}.
WithTxConfig(txConfig).
WithAccountNumber(50).
WithSequence(23).
Expand Down Expand Up @@ -267,7 +266,7 @@ func TestSign(t *testing.T) {

testCases := []struct {
name string
txf tx.Factory
txf Factory
txb client.TxBuilder
from string
overwrite bool
Expand Down Expand Up @@ -354,7 +353,7 @@ func TestSign(t *testing.T) {
var prevSigs []signingtypes.SignatureV2
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = tx.Sign(context.TODO(), tc.txf, tc.from, tc.txb, tc.overwrite)
err = Sign(context.TODO(), tc.txf, tc.from, tc.txb, tc.overwrite)
if len(tc.expectedPKs) == 0 {
requireT.Error(err)
} else {
Expand Down Expand Up @@ -409,7 +408,7 @@ func TestPreprocessHook(t *testing.T) {
txb, err := txfDirect.BuildUnsignedTx(msg1, msg2)
requireT.NoError(err)

err = tx.Sign(context.TODO(), txfDirect, from, txb, false)
err = Sign(context.TODO(), txfDirect, from, txb, false)
requireT.NoError(err)

// Run preprocessing
Expand Down
Loading