diff --git a/integration-tests/common/common.go b/integration-tests/common/common.go index 2d6da4044..9e0a84d42 100644 --- a/integration-tests/common/common.go +++ b/integration-tests/common/common.go @@ -51,13 +51,14 @@ type TestEnvDetails struct { } type RPCDetails struct { - RPCL1Internal string - RPCL2Internal string - RPCL1External string - RPCL2External string - MockServerUrl string - MockServerEndpoint string - P2PPort string + RPCL1Internal string + RPCL2Internal string + RPCL2InternalApiKey string + RPCL1External string + RPCL2External string + MockServerUrl string + MockServerEndpoint string + P2PPort string } func New(testConfig *testconfig.TestConfig) *Common { @@ -72,6 +73,11 @@ func New(testConfig *testconfig.TestConfig) *Common { if *testConfig.Common.Network == "testnet" { chainDetails = chainconfig.SepoliaConfig() chainDetails.L2RPCInternal = *testConfig.Common.L2RPCUrl + if testConfig.Common.L2RPCApiKey == nil { + chainDetails.L2RPCInternalApiKey = "" + } else { + chainDetails.L2RPCInternalApiKey = *testConfig.Common.L2RPCApiKey + } } else { // set up mocked local feedernet server because starknet-devnet does not provide one localDevnetFeederSrv := starknet.NewTestServer() @@ -85,8 +91,9 @@ func New(testConfig *testconfig.TestConfig) *Common { TestDuration: duration, }, RPCDetails: &RPCDetails{ - P2PPort: "6690", - RPCL2Internal: chainDetails.L2RPCInternal, + P2PPort: "6690", + RPCL2Internal: chainDetails.L2RPCInternal, + RPCL2InternalApiKey: chainDetails.L2RPCInternalApiKey, }, } @@ -147,8 +154,9 @@ func (c *Common) DefaultNodeConfig() *cl.Config { FeederURL: common_cfg.MustParseURL(c.ChainDetails.FeederURL), Nodes: []*config.Node{ { - Name: ptr.Ptr("primary"), - URL: common_cfg.MustParseURL(c.RPCDetails.RPCL2Internal), + Name: ptr.Ptr("primary"), + URL: common_cfg.MustParseURL(c.RPCDetails.RPCL2Internal), + APIKey: ptr.Ptr(c.RPCDetails.RPCL2InternalApiKey), }, }, } diff --git a/integration-tests/common/gauntlet_common.go b/integration-tests/common/gauntlet_common.go index f293f4a5f..02bedb6c0 100644 --- a/integration-tests/common/gauntlet_common.go +++ b/integration-tests/common/gauntlet_common.go @@ -4,8 +4,9 @@ import ( "encoding/json" "errors" "fmt" - "github.com/smartcontractkit/chainlink-starknet/integration-tests/utils" "os" + + "github.com/smartcontractkit/chainlink-starknet/integration-tests/utils" ) func (m *OCRv2TestState) fundNodes() ([]string, error) { @@ -26,7 +27,7 @@ func (m *OCRv2TestState) fundNodes() ([]string, error) { for _, key := range nAccounts { // We are not deploying in parallel here due to testnet limitations (429 too many requests) l.Debug().Msg(fmt.Sprintf("Funding node with address: %s", key)) - _, err := m.Clients.GauntletClient.TransferToken(m.Common.ChainDetails.StarkTokenAddress, key, "100000000000000000") // Transferring 0.1 STRK to each node + _, err := m.Clients.GauntletClient.TransferToken(m.Common.ChainDetails.StarkTokenAddress, key, "1000000000000000000") // Transferring 0.1 STRK to each node if err != nil { return nil, err } diff --git a/integration-tests/common/test_common.go b/integration-tests/common/test_common.go index fcec8f904..e96efdc6d 100644 --- a/integration-tests/common/test_common.go +++ b/integration-tests/common/test_common.go @@ -3,21 +3,23 @@ package common import ( "context" "fmt" + "net/http" + starknetdevnet "github.com/NethermindEth/starknet.go/devnet" "github.com/go-resty/resty/v2" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/smartcontractkit/chainlink-common/pkg/logger" test_env_ctf "github.com/smartcontractkit/chainlink-testing-framework/docker/test_env" - "net/http" + + "math/big" + "testing" + "time" test_env_starknet "github.com/smartcontractkit/chainlink-starknet/integration-tests/docker/test_env" "github.com/smartcontractkit/chainlink-starknet/integration-tests/testconfig" "github.com/smartcontractkit/chainlink-testing-framework/logging" "github.com/smartcontractkit/chainlink/integration-tests/docker/test_env" - "math/big" - "testing" - "time" "github.com/NethermindEth/juno/core/felt" starknetutils "github.com/NethermindEth/starknet.go/utils" @@ -200,7 +202,7 @@ func (m *OCRv2TestState) DeployCluster() { require.NoError(m.TestConfig.T, m.TestConfig.err) } lggr := logger.Nop() - m.Clients.StarknetClient, m.TestConfig.err = starknet.NewClient(m.Common.ChainDetails.ChainID, m.Common.RPCDetails.RPCL2External, lggr, &rpcRequestTimeout) + m.Clients.StarknetClient, m.TestConfig.err = starknet.NewClient(m.Common.ChainDetails.ChainID, m.Common.RPCDetails.RPCL2External, m.Common.RPCDetails.RPCL2InternalApiKey, lggr, &rpcRequestTimeout) require.NoError(m.TestConfig.T, m.TestConfig.err, "Creating starknet client should not fail") m.Clients.OCR2Client, m.TestConfig.err = ocr2.NewClient(m.Clients.StarknetClient, lggr) require.NoError(m.TestConfig.T, m.TestConfig.err, "Creating ocr2 client should not fail") diff --git a/integration-tests/config/config.go b/integration-tests/config/config.go index 405cac3be..6bdeed681 100644 --- a/integration-tests/config/config.go +++ b/integration-tests/config/config.go @@ -5,12 +5,13 @@ var ( ) type Config struct { - ChainName string - ChainID string - StarkTokenAddress string - L2RPCInternal string - TokenName string - FeederURL string + ChainName string + ChainID string + StarkTokenAddress string + L2RPCInternal string + L2RPCInternalApiKey string + TokenName string + FeederURL string } func SepoliaConfig() *Config { diff --git a/integration-tests/testconfig/testconfig.go b/integration-tests/testconfig/testconfig.go index c47684018..50a395f17 100644 --- a/integration-tests/testconfig/testconfig.go +++ b/integration-tests/testconfig/testconfig.go @@ -4,10 +4,11 @@ import ( "embed" "encoding/base64" "fmt" - "github.com/smartcontractkit/chainlink-testing-framework/docker/test_env" "os" "strings" + "github.com/smartcontractkit/chainlink-testing-framework/docker/test_env" + "github.com/barkimedes/go-deepcopy" "github.com/google/uuid" "github.com/pelletier/go-toml/v2" @@ -109,9 +110,11 @@ func (c *TestConfig) AsBase64() (string, error) { } type Common struct { - Network *string `toml:"network"` - InsideK8s *bool `toml:"inside_k8"` - User *string `toml:"user"` + Network *string `toml:"network"` + InsideK8s *bool `toml:"inside_k8"` + User *string `toml:"user"` + // if rpc requires api key to be passed as an HTTP header + L2RPCApiKey *string `toml:"l2_rpc_url_api_key"` L2RPCUrl *string `toml:"l2_rpc_url"` PrivateKey *string `toml:"private_key"` Account *string `toml:"account"` diff --git a/monitoring/cmd/monitoring/main.go b/monitoring/cmd/monitoring/main.go index 89f79a8c0..b4fd6f466 100644 --- a/monitoring/cmd/monitoring/main.go +++ b/monitoring/cmd/monitoring/main.go @@ -36,6 +36,7 @@ func main() { starknetClient, err := starknet.NewClient( starknetConfig.GetChainID(), starknetConfig.GetRPCEndpoint(), + starknetConfig.GetRPCApiKey(), logger.With(log, "component", "starknet-client"), &readTimeout, ) diff --git a/monitoring/pkg/monitoring/config_chain.go b/monitoring/pkg/monitoring/config_chain.go index e29c95aa0..b7c44a6d5 100644 --- a/monitoring/pkg/monitoring/config_chain.go +++ b/monitoring/pkg/monitoring/config_chain.go @@ -11,6 +11,7 @@ import ( type StarknetConfig struct { rpcEndpoint string + rpcApiKey string networkName string networkID string chainID string @@ -22,6 +23,7 @@ type StarknetConfig struct { var _ relayMonitoring.ChainConfig = StarknetConfig{} func (s StarknetConfig) GetRPCEndpoint() string { return s.rpcEndpoint } +func (s StarknetConfig) GetRPCApiKey() string { return s.rpcApiKey } func (s StarknetConfig) GetNetworkName() string { return s.networkName } func (s StarknetConfig) GetNetworkID() string { return s.networkID } func (s StarknetConfig) GetChainID() string { return s.chainID } @@ -54,6 +56,9 @@ func parseEnvVars(cfg *StarknetConfig) error { if value, isPresent := os.LookupEnv("STARKNET_RPC_ENDPOINT"); isPresent { cfg.rpcEndpoint = value } + if value, isPresent := os.LookupEnv("STARKNET_RPC_API_KEY"); isPresent { + cfg.rpcApiKey = value + } if value, isPresent := os.LookupEnv("STARKNET_NETWORK_NAME"); isPresent { cfg.networkName = value } diff --git a/relayer/pkg/chainlink/chain/chain.go b/relayer/pkg/chainlink/chain/chain.go index 5763bcaf3..909ee588a 100644 --- a/relayer/pkg/chainlink/chain/chain.go +++ b/relayer/pkg/chainlink/chain/chain.go @@ -143,7 +143,7 @@ func (c *chain) getClient() (*starknet.Client, error) { for _, i := range index { node = nodes[i] // create client and check - client, err = starknet.NewClient(node.ChainID, node.URL, c.lggr, &timeout) + client, err = starknet.NewClient(node.ChainID, node.URL, node.APIKey, c.lggr, &timeout) // if error, try another node if err != nil { c.lggr.Warnw("failed to create node", "name", node.Name, "starknet-url", node.URL, "err", err.Error()) diff --git a/relayer/pkg/chainlink/config/config.go b/relayer/pkg/chainlink/config/config.go index 9cf9431e7..a3a09e3c1 100644 --- a/relayer/pkg/chainlink/config/config.go +++ b/relayer/pkg/chainlink/config/config.go @@ -75,6 +75,8 @@ func (c *Chain) SetDefaults() { type Node struct { Name *string URL *config.URL + // only if rpc url needs api key passed in header + APIKey *string } type TOMLConfigs []*TOMLConfig @@ -227,6 +229,7 @@ func legacyNode(n *Node, id string) db.Node { Name: *n.Name, ChainID: id, URL: (*url.URL)(n.URL).String(), + APIKey: *n.APIKey, } } diff --git a/relayer/pkg/chainlink/db/db.go b/relayer/pkg/chainlink/db/db.go index e82f61848..2f089753e 100644 --- a/relayer/pkg/chainlink/db/db.go +++ b/relayer/pkg/chainlink/db/db.go @@ -9,6 +9,7 @@ type Node struct { Name string ChainID string `db:"starknet_chain_id"` URL string + APIKey string CreatedAt time.Time UpdatedAt time.Time } diff --git a/relayer/pkg/chainlink/ocr2/client_test.go b/relayer/pkg/chainlink/ocr2/client_test.go index 436fd8bc3..af61b13bc 100644 --- a/relayer/pkg/chainlink/ocr2/client_test.go +++ b/relayer/pkg/chainlink/ocr2/client_test.go @@ -93,7 +93,7 @@ func TestOCR2Client(t *testing.T) { url := mockServer.URL duration := 10 * time.Second - reader, err := starknet.NewClient(chainID, url, lggr, &duration) + reader, err := starknet.NewClient(chainID, url, "", lggr, &duration) require.NoError(t, err) client, err := NewClient(reader, lggr) assert.NoError(t, err) diff --git a/relayer/pkg/chainlink/txm/nonce.go b/relayer/pkg/chainlink/txm/nonce.go index 90e6cfbcb..437b6d9c4 100644 --- a/relayer/pkg/chainlink/txm/nonce.go +++ b/relayer/pkg/chainlink/txm/nonce.go @@ -24,7 +24,7 @@ type NonceManager interface { NextSequence(address *felt.Felt, chainID string) (*felt.Felt, error) IncrementNextSequence(address *felt.Felt, chainID string, currentNonce *felt.Felt) error // Resets local account nonce to on-chain account nonce - Sync(ctx context.Context, address *felt.Felt, chainId string, client NonceManagerClient) error + Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error } var _ NonceManager = (*nonceManager)(nil) @@ -64,7 +64,7 @@ func (nm *nonceManager) HealthReport() map[string]error { return map[string]error{nm.Name(): nm.starter.Healthy()} } -func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, chainId string, client NonceManagerClient) error { +func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error { if err := nm.validate(address, chainId); err != nil { return err } @@ -76,7 +76,7 @@ func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, chainId st return err } - nm.n[address.String()][chainId] = n + nm.n[publicKey.String()][chainId] = n return nil } @@ -101,40 +101,40 @@ func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey return nil } -func (nm *nonceManager) NextSequence(addr *felt.Felt, chainId string) (*felt.Felt, error) { - if err := nm.validate(addr, chainId); err != nil { +func (nm *nonceManager) NextSequence(publicKey *felt.Felt, chainId string) (*felt.Felt, error) { + if err := nm.validate(publicKey, chainId); err != nil { return nil, err } nm.lock.RLock() defer nm.lock.RUnlock() - return nm.n[addr.String()][chainId], nil + return nm.n[publicKey.String()][chainId], nil } -func (nm *nonceManager) IncrementNextSequence(addr *felt.Felt, chainId string, currentNonce *felt.Felt) error { - if err := nm.validate(addr, chainId); err != nil { +func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, chainId string, currentNonce *felt.Felt) error { + if err := nm.validate(publicKey, chainId); err != nil { return err } nm.lock.Lock() defer nm.lock.Unlock() - n := nm.n[addr.String()][chainId] + n := nm.n[publicKey.String()][chainId] if n.Cmp(currentNonce) != 0 { - return fmt.Errorf("mismatched nonce for %s: %s (expected) != %s (got)", addr, n, currentNonce) + return fmt.Errorf("mismatched nonce for %s: %s (expected) != %s (got)", publicKey, n, currentNonce) } one := new(felt.Felt).SetUint64(1) - nm.n[addr.String()][chainId] = new(felt.Felt).Add(n, one) + nm.n[publicKey.String()][chainId] = new(felt.Felt).Add(n, one) return nil } -func (nm *nonceManager) validate(addr *felt.Felt, chainId string) error { +func (nm *nonceManager) validate(publicKey *felt.Felt, chainId string) error { nm.lock.RLock() defer nm.lock.RUnlock() - if _, exists := nm.n[addr.String()]; !exists { - return fmt.Errorf("nonce tracking does not exist for key: %s", addr.String()) + if _, exists := nm.n[publicKey.String()]; !exists { + return fmt.Errorf("nonce tracking does not exist for key: %s", publicKey.String()) } - if _, exists := nm.n[addr.String()][chainId]; !exists { - return fmt.Errorf("nonce does not exist for key: %s and chain: %s", addr.String(), chainId) + if _, exists := nm.n[publicKey.String()][chainId]; !exists { + return fmt.Errorf("nonce does not exist for key: %s and chain: %s", publicKey.String(), chainId) } return nil } diff --git a/relayer/pkg/chainlink/txm/txm.go b/relayer/pkg/chainlink/txm/txm.go index 65a1c8a67..38a5cd1a2 100644 --- a/relayer/pkg/chainlink/txm/txm.go +++ b/relayer/pkg/chainlink/txm/txm.go @@ -15,6 +15,8 @@ import ( starknetutils "github.com/NethermindEth/starknet.go/utils" "golang.org/x/exp/maps" + pkgerrors "github.com/pkg/errors" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -129,7 +131,28 @@ func (txm *starktxm) broadcastLoop() { } } -func (txm *starktxm) handleNonceErr(ctx context.Context, accountAddress *felt.Felt) error { +func (txm *starktxm) handleBroadcastErr(ctx context.Context, data any, accountAddress *felt.Felt, publicKey *felt.Felt, call starknetrpc.FunctionCall) error { + + errData := fmt.Sprintf("%s", data) + txm.lggr.Debug("encountered handleBroadcastErr", errData) + + if isInvalidNonce(errData) { + // resubmits all unconfirmed transactions + err := txm.handleNonceErr(ctx, accountAddress, publicKey) + if err != nil { + return pkgerrors.Wrap(err, "error in nonce handling") + } + // resubmits the current 1 unbroadcasted tx that just failed + err = txm.Enqueue(accountAddress, publicKey, call) + if err != nil { + return pkgerrors.Wrap(err, "error in re-enqueuing after nonce handling") + } + } + + return nil +} + +func (txm *starktxm) handleNonceErr(ctx context.Context, accountAddress *felt.Felt, publicKey *felt.Felt) error { txm.lggr.Debugw("Handling Nonce Validation Error By Resubmitting Txs...", "account", accountAddress) @@ -144,9 +167,20 @@ func (txm *starktxm) handleNonceErr(ctx context.Context, accountAddress *felt.Fe return err } - txm.nonce.Sync(ctx, accountAddress, chainId, client) + // get current nonce before syncing (for logging purposes) + oldVal, err := txm.nonce.NextSequence(publicKey, chainId) + if err != nil { + return err + } - // todo: in the future, revisit resetting txm.txStore.currentNonce + txm.nonce.Sync(ctx, accountAddress, publicKey, chainId, client) + + getVal, err := txm.nonce.NextSequence(publicKey, chainId) + if err != nil { + return err + } + + txm.lggr.Debug("prior nonce: ", oldVal, "new nonce: ", getVal) unconfirmedTxs, err := txm.txStore.GetUnconfirmedSorted(accountAddress) if err != nil { @@ -243,22 +277,16 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun tx.Signature = signature // get fee for tx - // optional - pass nonce to fee estimate (if nonce gets ahead, estimate may fail) - // can we estimate fee without calling estimate - tbd with 1.0 - simFlags := []starknetrpc.SimulationFlag{} - feeEstimate, err := account.EstimateFee(ctx, []starknetrpc.BroadcastTxn{tx}, simFlags, starknetrpc.BlockID{Tag: "latest"}) + simFlags := []starknetrpc.SimulationFlag{starknetrpc.SKIP_VALIDATE} + feeEstimate, err := account.EstimateFee(ctx, []starknetrpc.BroadcastTxn{tx}, simFlags, starknetrpc.BlockID{Tag: "pending"}) if err != nil { var data any if err, ok := err.(ethrpc.DataError); ok { data = err.ErrorData() - errData := fmt.Sprintf("%s", data) - txm.lggr.Debug("err data formatted as string", errData) - - if isInvalidNonce(errData) { - // resubmits all unconfirmed transactions - txm.handleNonceErr(ctx, accountAddress) - // resubmits the current 1 unbroadcasted tx that just failed - txm.Enqueue(accountAddress, publicKey, call) + + err := txm.handleBroadcastErr(ctx, data, accountAddress, publicKey, call) + if err != nil { + return txhash, err } } @@ -280,14 +308,18 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun } // TODO: does v3 tx uses fri instead of wei? check feeEstimate[0].FeeUnit? - // pad estimate to 110% + + // pad estimate to 140% (add extra because estimate did not include validation) gasConsumed := friEstimate.GasConsumed.BigInt(new(big.Int)) expandedGas := new(big.Int).Mul(gasConsumed, big.NewInt(140)) maxGas := new(big.Int).Div(expandedGas, big.NewInt(100)) tx.ResourceBounds.L1Gas.MaxAmount = starknetrpc.U64(starknetutils.BigIntToFelt(maxGas).String()) - // TODO: add margin - tx.ResourceBounds.L1Gas.MaxPricePerUnit = starknetrpc.U128(friEstimate.GasPrice.String()) + // pad by 120% + gasPrice := friEstimate.GasPrice.BigInt(new(big.Int)) + expandedGasPrice := new(big.Int).Mul(gasPrice, big.NewInt(120)) + maxGasPrice := new(big.Int).Div(expandedGasPrice, big.NewInt(100)) + tx.ResourceBounds.L1Gas.MaxPricePerUnit = starknetrpc.U128(starknetutils.BigIntToFelt(maxGasPrice).String()) txm.lggr.Infow("Set resource bounds", "L1MaxAmount", tx.ResourceBounds.L1Gas.MaxAmount, "L1MaxPricePerUnit", tx.ResourceBounds.L1Gas.MaxPricePerUnit) @@ -313,14 +345,10 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun var data any if err, ok := err.(ethrpc.DataError); ok { data = err.ErrorData() - errData := fmt.Sprintf("%s", data) - txm.lggr.Debug("err data formatted as string", errData) - - if isInvalidNonce(errData) { - // resubmits all unconfirmed transactions - txm.handleNonceErr(ctx, accountAddress) - // resubmits the current 1 unbroadcasted tx that just failed - txm.Enqueue(accountAddress, publicKey, call) + + err := txm.handleBroadcastErr(ctx, data, accountAddress, publicKey, call) + if err != nil { + return txhash, err } } txm.lggr.Errorw("failed to invoke tx from address", accountAddress, "error", err, "data", data) @@ -396,8 +424,17 @@ func (txm *starktxm) confirmLoop() { } if isInvalidNonce(rejectedTx.ErrorMessage) { + + utx, err := txm.txStore.GetSingleUnconfirmed(addr, hash) + if err != nil { + txm.lggr.Errorw("failed to fetch unconfirmed tx from txstore", err) + } // resubmits all unconfirmed transactions (includes the current one that just failed) - txm.handleNonceErr(ctx, addr) + err = txm.handleNonceErr(ctx, addr, utx.PublicKey) + + if err != nil { + txm.lggr.Errorw("error in nonce handling: ", err) + } // move on to process next address's txs because // unconfirmed txs for this address are out of date because they have been purged and resubmitted // we'll reprocess this address's txs on the next cycle of the confirm loop diff --git a/relayer/pkg/chainlink/txm/txm_test.go b/relayer/pkg/chainlink/txm/txm_test.go index a9fe18cac..f36a4c4bc 100644 --- a/relayer/pkg/chainlink/txm/txm_test.go +++ b/relayer/pkg/chainlink/txm/txm_test.go @@ -60,7 +60,7 @@ func TestIntegration_Txm(t *testing.T) { lggr, observer := logger.TestObserved(t, zapcore.DebugLevel) timeout := 10 * time.Second - client, err := starknet.NewClient("SN_GOERLI", url+"/rpc", lggr, &timeout) + client, err := starknet.NewClient("SN_GOERLI", url+"/rpc", "", lggr, &timeout) require.NoError(t, err) getFeederClient := func() (*starknet.FeederClient, error) { diff --git a/relayer/pkg/chainlink/txm/txstore.go b/relayer/pkg/chainlink/txm/txstore.go index bc3e78bc3..27cfd4d31 100644 --- a/relayer/pkg/chainlink/txm/txstore.go +++ b/relayer/pkg/chainlink/txm/txstore.go @@ -1,6 +1,7 @@ package txm import ( + "errors" "fmt" "sort" "sync" @@ -12,32 +13,45 @@ import ( // TxStore tracks broadcast & unconfirmed txs per account address per chain id type TxStore struct { - lock sync.RWMutex - nonceToHash map[felt.Felt]string // map nonce to txhash - hashToNonce map[string]felt.Felt // map hash to nonce - currentNonce felt.Felt // minimum nonce - hashToCall map[string]*starknetrpc.FunctionCall - hashToKey map[string]felt.Felt + lock sync.RWMutex + nonceToHash map[felt.Felt]string // map nonce to txhash + hashToNonce map[string]felt.Felt // map hash to nonce + hashToCall map[string]*starknetrpc.FunctionCall + hashToKey map[string]felt.Felt } -func NewTxStore(current *felt.Felt) *TxStore { +func NewTxStore() *TxStore { return &TxStore{ - currentNonce: *current, - nonceToHash: map[felt.Felt]string{}, - hashToNonce: map[string]felt.Felt{}, - hashToCall: map[string]*starknetrpc.FunctionCall{}, - hashToKey: map[string]felt.Felt{}, + nonceToHash: map[felt.Felt]string{}, + hashToNonce: map[string]felt.Felt{}, + hashToCall: map[string]*starknetrpc.FunctionCall{}, + hashToKey: map[string]felt.Felt{}, } } -// TODO: Save should make a copy otherwise wee're modiffying the same memory and could loop +func deepCopy(nonce *felt.Felt, call *starknetrpc.FunctionCall, publicKey *felt.Felt) (newNonce *felt.Felt, newCall *starknetrpc.FunctionCall, newPublicKey *felt.Felt) { + newNonce = new(felt.Felt).Set(nonce) + newPublicKey = new(felt.Felt).Set(publicKey) + newCall = copyCall(call) + return +} + +func copyCall(call *starknetrpc.FunctionCall) *starknetrpc.FunctionCall { + copyCall := starknetrpc.FunctionCall{ + ContractAddress: new(felt.Felt).Set(call.ContractAddress), + EntryPointSelector: new(felt.Felt).Set(call.EntryPointSelector), + Calldata: []*felt.Felt{}, + } + for i := 0; i < len(call.Calldata); i++ { + copyCall.Calldata = append(copyCall.Calldata, new(felt.Felt).Set(call.Calldata[i])) + } + return ©Call +} + func (s *TxStore) Save(nonce *felt.Felt, hash string, call *starknetrpc.FunctionCall, publicKey *felt.Felt) error { s.lock.Lock() defer s.lock.Unlock() - if s.currentNonce.Cmp(nonce) == 1 { - return fmt.Errorf("nonce too low: %s < %s (lowest)", nonce, &s.currentNonce) - } if h, exists := s.nonceToHash[*nonce]; exists { return fmt.Errorf("nonce used: tried to use nonce (%s) for tx (%s), already used by (%s)", nonce, hash, h) } @@ -45,19 +59,15 @@ func (s *TxStore) Save(nonce *felt.Felt, hash string, call *starknetrpc.Function return fmt.Errorf("hash used: tried to use tx (%s) for nonce (%s), already used nonce (%s)", hash, nonce, &n) } + newNonce, newCall, newPublicKey := deepCopy(nonce, call, publicKey) + // store hash - s.nonceToHash[*nonce] = hash + s.nonceToHash[*newNonce] = hash - s.hashToNonce[hash] = *nonce - s.hashToCall[hash] = call - s.hashToKey[hash] = *publicKey + s.hashToNonce[hash] = *newNonce + s.hashToCall[hash] = newCall + s.hashToKey[hash] = *newPublicKey - // find next unused nonce - _, exists := s.nonceToHash[s.currentNonce] - for exists { - s.currentNonce = *new(felt.Felt).Add(&s.currentNonce, new(felt.Felt).SetUint64(1)) - _, exists = s.nonceToHash[s.currentNonce] - } return nil } @@ -89,10 +99,32 @@ type UnconfirmedTx struct { Call *starknetrpc.FunctionCall } +func (s *TxStore) GetSingleUnconfirmed(hash string) (tx UnconfirmedTx, err error) { + s.lock.RLock() + defer s.lock.RUnlock() + + n, hExists := s.hashToNonce[hash] + c, cExists := s.hashToCall[hash] + k, kExists := s.hashToKey[hash] + + if !hExists || !cExists || !kExists { + return tx, errors.New("datum not found in txstore") + } + + newNonce, newCall, newPublicKey := deepCopy(&n, c, &k) + + tx.Call = newCall + tx.Nonce = newNonce + tx.PublicKey = newPublicKey + tx.Hash = hash + + return tx, nil +} + // Retrieve Unconfirmed Txs in their queued order (by nonce) func (s *TxStore) GetUnconfirmedSorted() (txs []UnconfirmedTx) { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() nonces := maps.Values(s.hashToNonce) sort.Slice(nonces, func(i, j int) bool { @@ -102,7 +134,12 @@ func (s *TxStore) GetUnconfirmedSorted() (txs []UnconfirmedTx) { for i := 0; i < len(nonces); i++ { n := nonces[i] h := s.nonceToHash[n] - txs = append(txs, UnconfirmedTx{Hash: h, Nonce: &n, Call: s.hashToCall[h]}) + k := s.hashToKey[h] + c := s.hashToCall[h] + + newNonce, newCall, newPublicKey := deepCopy(&n, c, &k) + + txs = append(txs, UnconfirmedTx{Hash: h, Nonce: newNonce, Call: newCall, PublicKey: newPublicKey}) } return txs @@ -131,7 +168,7 @@ func (c *ChainTxStore) Save(from *felt.Felt, nonce *felt.Felt, hash string, call defer c.lock.Unlock() if err := c.validate(from); err != nil { // if does not exist, create a new store for the address - c.store[from] = NewTxStore(nonce) + c.store[from] = NewTxStore() } return c.store[from].Save(nonce, hash, call, publicKey) } @@ -148,15 +185,25 @@ func (c *ChainTxStore) Confirm(from *felt.Felt, hash string) error { } func (c *ChainTxStore) GetUnconfirmedSorted(from *felt.Felt) ([]UnconfirmedTx, error) { - c.lock.Lock() - defer c.lock.Unlock() + c.lock.RLock() + defer c.lock.RUnlock() if err := c.validate(from); err != nil { return nil, err } return c.store[from].GetUnconfirmedSorted(), nil +} + +func (c *ChainTxStore) GetSingleUnconfirmed(from *felt.Felt, hash string) (tx UnconfirmedTx, err error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if err := c.validate(from); err != nil { + return tx, err + } + return c.store[from].GetSingleUnconfirmed(hash) } func (c *ChainTxStore) GetAllInflightCount() map[*felt.Felt]int { diff --git a/relayer/pkg/chainlink/txm/txstore_test.go b/relayer/pkg/chainlink/txm/txstore_test.go index 608b458fb..77d73f171 100644 --- a/relayer/pkg/chainlink/txm/txstore_test.go +++ b/relayer/pkg/chainlink/txm/txstore_test.go @@ -25,7 +25,7 @@ func TestTxStore(t *testing.T) { feltKey := new(felt.Felt).SetUint64(7) - s := NewTxStore(&felt.Zero) + s := NewTxStore() assert.Equal(t, 0, s.InflightCount()) require.NoError(t, s.Save(new(felt.Felt).SetUint64(0), "0x0", call, feltKey)) assert.Equal(t, 1, s.InflightCount()) @@ -39,7 +39,7 @@ func TestTxStore(t *testing.T) { t.Parallel() // create - s := NewTxStore(new(felt.Felt).SetUint64(0)) + s := NewTxStore() call := &starknetrpc.FunctionCall{ ContractAddress: new(felt.Felt).SetUint64(0), @@ -51,26 +51,18 @@ func TestTxStore(t *testing.T) { // accepts tx in order require.NoError(t, s.Save(new(felt.Felt).SetUint64(0), "0x0", call, feltKey)) assert.Equal(t, 1, s.InflightCount()) - assert.Equal(t, new(felt.Felt).SetUint64(1), &s.currentNonce) // accepts tx that skips a nonce require.NoError(t, s.Save(new(felt.Felt).SetUint64(2), "0x2", call, feltKey)) assert.Equal(t, 2, s.InflightCount()) - assert.Equal(t, new(felt.Felt).SetUint64(1), &s.currentNonce) - // accepts tx that fills in the missing nonce + fast forwards currentNonce + // accepts tx that fills in the missing nonce require.NoError(t, s.Save(new(felt.Felt).SetUint64(1), "0x1", call, feltKey)) assert.Equal(t, 3, s.InflightCount()) - assert.Equal(t, new(felt.Felt).SetUint64(3), &s.currentNonce) // skip a nonce for later tests require.NoError(t, s.Save(new(felt.Felt).SetUint64(4), "0x4", call, feltKey)) assert.Equal(t, 4, s.InflightCount()) - assert.Equal(t, new(felt.Felt).SetUint64(3), &s.currentNonce) - - // rejects old nonce - require.ErrorContains(t, s.Save(new(felt.Felt).SetUint64(0), "0xold", call, feltKey), "nonce too low: 0x0 < 0x3 (lowest)") - assert.Equal(t, 4, s.InflightCount()) // reject already in use nonce require.ErrorContains(t, s.Save(new(felt.Felt).SetUint64(4), "0xskip", call, feltKey), "nonce used: tried to use nonce (0x4) for tx (0xskip), already used by (0x4)") @@ -108,7 +100,7 @@ func TestTxStore(t *testing.T) { feltKey := new(felt.Felt).SetUint64(7) // init store - s := NewTxStore(new(felt.Felt).SetUint64(0)) + s := NewTxStore() for i := 0; i < 5; i++ { require.NoError(t, s.Save(new(felt.Felt).SetUint64(uint64(i)), "0x"+fmt.Sprintf("%d", i), call, feltKey)) } diff --git a/relayer/pkg/starknet/client.go b/relayer/pkg/starknet/client.go index 02de45433..e01b98ffc 100644 --- a/relayer/pkg/starknet/client.go +++ b/relayer/pkg/starknet/client.go @@ -2,6 +2,7 @@ package starknet import ( "context" + "strings" "time" "github.com/pkg/errors" @@ -48,9 +49,13 @@ type Client struct { } // pass nil or 0 to timeout to not use built in default timeout -func NewClient(_chainID string, baseURL string, lggr logger.Logger, timeout *time.Duration) (*Client, error) { +func NewClient(_chainID string, baseURL string, apiKey string, lggr logger.Logger, timeout *time.Duration) (*Client, error) { // TODO: chainID now unused c, err := ethrpc.DialContext(context.Background(), baseURL) + if strings.TrimSpace(apiKey) != "" { + c.SetHeader("x-apikey", apiKey) + } + if err != nil { return nil, err } @@ -203,5 +208,5 @@ func (c *Client) AccountNonce(ctx context.Context, accountAddress *felt.Felt) (* if err != nil { return nil, errors.Wrap(err, "error in client.AccountNonce") } - return account.Nonce(ctx, starknetrpc.BlockID{Tag: "latest"}, account.AccountAddress) + return account.Nonce(ctx, starknetrpc.BlockID{Tag: "pending"}, account.AccountAddress) } diff --git a/relayer/pkg/starknet/client_test.go b/relayer/pkg/starknet/client_test.go index e3d88b24a..ee67900ea 100644 --- a/relayer/pkg/starknet/client_test.go +++ b/relayer/pkg/starknet/client_test.go @@ -52,7 +52,7 @@ func TestRPCClient(t *testing.T) { defer mockServer.Close() lggr := logger.Test(t) - client, err := NewClient(chainID, mockServer.URL, lggr, &timeout) + client, err := NewClient(chainID, mockServer.URL, "", lggr, &timeout) require.NoError(t, err) assert.Equal(t, timeout, client.defaultTimeout)