From 48de090b70eb1b95692ac480f2431792a8fa1887 Mon Sep 17 00:00:00 2001 From: vro <168573323+golangisfun123@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:35:42 -0500 Subject: [PATCH] feat(relconfig): decimals config validation (#2919) * decimals implementation * add test * add test * lint * [goreleaser] * handle native gas token * combine into Validate func * commnent * [goreleaser] * better call to loadconfig and abstract away validate * abstract away validate * [goreleaser] --------- Co-authored-by: Trajan0x --- services/rfq/relayer/cmd/commands.go | 5 +- services/rfq/relayer/relconfig/config.go | 62 +++++++++++++- services/rfq/relayer/relconfig/config_test.go | 82 ++++++++++++++++++- services/rfq/relayer/relconfig/suite_test.go | 50 +++++++++++ 4 files changed, 189 insertions(+), 10 deletions(-) create mode 100644 services/rfq/relayer/relconfig/suite_test.go diff --git a/services/rfq/relayer/cmd/commands.go b/services/rfq/relayer/cmd/commands.go index 31697c730b..0a5e1f6f0e 100644 --- a/services/rfq/relayer/cmd/commands.go +++ b/services/rfq/relayer/cmd/commands.go @@ -32,13 +32,14 @@ var runCommand = &cli.Command{ Flags: []cli.Flag{configFlag, &commandline.LogLevel}, Action: func(c *cli.Context) (err error) { commandline.SetLogLevel(c) + + metricsProvider := metrics.Get() + cfg, err := relconfig.LoadConfig(core.ExpandOrReturnPath(c.String(configFlag.Name))) if err != nil { return fmt.Errorf("could not read config file: %w", err) } - metricsProvider := metrics.Get() - relayer, err := service.NewRelayer(c.Context, metricsProvider, cfg) if err != nil { return fmt.Errorf("could not create relayer: %w", err) diff --git a/services/rfq/relayer/relconfig/config.go b/services/rfq/relayer/relconfig/config.go index 36614772d0..7ed139cfec 100644 --- a/services/rfq/relayer/relconfig/config.go +++ b/services/rfq/relayer/relconfig/config.go @@ -2,6 +2,7 @@ package relconfig import ( + "context" "fmt" "math" "os" @@ -9,14 +10,20 @@ import ( "strings" "time" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/jftuga/ellipsis" + "github.com/synapsecns/sanguine/core/metrics" "github.com/synapsecns/sanguine/ethergo/signer/config" submitterConfig "github.com/synapsecns/sanguine/ethergo/submitter/config" cctpConfig "github.com/synapsecns/sanguine/services/cctp-relayer/config" + "github.com/synapsecns/sanguine/services/rfq/contracts/ierc20" + "github.com/synapsecns/sanguine/services/rfq/relayer/chain" "gopkg.in/yaml.v2" "path/filepath" + + omniClient "github.com/synapsecns/sanguine/services/omnirpc/client" ) // Config represents the configuration for the relayer. @@ -199,15 +206,18 @@ func LoadConfig(path string) (config Config, err error) { if err != nil { return Config{}, fmt.Errorf("could not unmarshall config %s: %w", ellipsis.Shorten(string(input), 30), err) } - err = config.Validate() + omniClient := omniClient.NewOmnirpcClient(config.OmniRPCURL, metrics.NewNullHandler(), omniClient.WithCaptureReqRes()) + err = config.Validate(context.Background(), omniClient) if err != nil { - return config, fmt.Errorf("error validating config: %w", err) + return Config{}, fmt.Errorf("config validation failed: %w", err) } + return config, nil } -// Validate validates the config. -func (c Config) Validate() (err error) { +// Validate validates the config. Omniclient may be nil, but if not then it will also check the chain to see if the decimals +// match the actual token decimals. +func (c Config) Validate(ctx context.Context, omniclient omniClient.RPCClient) (err error) { maintenancePctSums := map[string]float64{} initialPctSums := map[string]float64{} for _, chainCfg := range c.Chains { @@ -228,5 +238,49 @@ func (c Config) Validate() (err error) { return fmt.Errorf("total initial percent does not total 100 for %s: %f", token, sum) } } + + if omniclient != nil { + err = c.validateTokenDecimals(ctx, omniclient) + if err != nil { + return fmt.Errorf("error validating token decimals: %w", err) + } + } + + return nil +} + +// ValidateTokenDecimals calls decimals() on the ERC20s to ensure that the decimals in the config match the actual token decimals. +func (c Config) validateTokenDecimals(ctx context.Context, omniClient omniClient.RPCClient) (err error) { + for chainID, chainCfg := range c.Chains { + for tokenName, tokenCFG := range chainCfg.Tokens { + chainClient, err := omniClient.GetChainClient(ctx, chainID) + if err != nil { + return fmt.Errorf("could not get chain client for chain %d: %w", chainID, err) + } + + // Check if the token is the gas token. SHOULD BE 18. + if tokenCFG.Address == chain.EthAddress.String() { + if tokenCFG.Decimals != 18 { + return fmt.Errorf("decimals mismatch for token %s on chain %d: expected 18, got %d", tokenName, chainID, tokenCFG.Decimals) + } + continue + } + + ierc20, err := ierc20.NewIERC20(common.HexToAddress(tokenCFG.Address), chainClient) + if err != nil { + return fmt.Errorf("could not create caller for token %s at address %s on chain %d: %w", tokenName, tokenCFG.Address, chainID, err) + } + + actualDecimals, err := ierc20.Decimals(&bind.CallOpts{Context: ctx}) + if err != nil { + return fmt.Errorf("could not get decimals for token %s on chain %d: %w", tokenName, chainID, err) + } + + if actualDecimals != tokenCFG.Decimals { + return fmt.Errorf("decimals mismatch for token %s on chain %d: expected %d, got %d", tokenName, chainID, tokenCFG.Decimals, actualDecimals) + } + } + } + return nil } diff --git a/services/rfq/relayer/relconfig/config_test.go b/services/rfq/relayer/relconfig/config_test.go index bac0016c52..1a3357f595 100644 --- a/services/rfq/relayer/relconfig/config_test.go +++ b/services/rfq/relayer/relconfig/config_test.go @@ -1,6 +1,7 @@ package relconfig_test import ( + "context" "testing" "time" @@ -372,7 +373,7 @@ func TestValidation(t *testing.T) { }, }, } - err := cfg.Validate() + err := cfg.Validate(context.Background(), nil) assert.Nil(t, err) }) @@ -399,7 +400,7 @@ func TestValidation(t *testing.T) { }, }, } - err := cfg.Validate() + err := cfg.Validate(context.Background(), nil) assert.NotNil(t, err) assert.Equal(t, "total initial percent does not total 100 for USDC: 101.000000", err.Error()) }) @@ -427,7 +428,7 @@ func TestValidation(t *testing.T) { }, }, } - err := cfg.Validate() + err := cfg.Validate(context.Background(), nil) assert.NotNil(t, err) assert.Equal(t, "total maintenance percent exceeds 100 for USDC: 100.100000", err.Error()) }) @@ -453,7 +454,7 @@ func TestValidation(t *testing.T) { }, }, } - err := cfg.Validate() + err := cfg.Validate(context.Background(), nil) assert.Nil(t, err) }) } @@ -504,3 +505,76 @@ func TestDecodeTokenID(t *testing.T) { }) } } + +const usdcAddr = "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48" +const arbAddr = "0x912CE59144191C1204E64559FE8253a0e49E6548" +const opAddr = "0x4200000000000000000000000000000000000042" + +func (v *ValidateDecimalsSuite) TestValidateWrongDecimals() { + cfg := relconfig.Config{ + Chains: map[int]relconfig.ChainConfig{ + 1: { + Tokens: map[string]relconfig.TokenConfig{ + "USDC": { + Address: usdcAddr, + Decimals: 18, // WRONG + }, + }, + }, + }, + } + err := cfg.Validate(v.GetTestContext(), v.omniClient) + // we should error because the decimals are wrong + v.Require().Error(err) +} + +func (v *ValidateDecimalsSuite) TestValidateCorrectDecimals() { + cfg := relconfig.Config{ + Chains: map[int]relconfig.ChainConfig{ + 1: { + Tokens: map[string]relconfig.TokenConfig{ + "USDC": { + Address: usdcAddr, + Decimals: 6, + }, + }, + }, + }, + } + err := cfg.Validate(v.GetTestContext(), v.omniClient) + v.Require().NoError(err) +} + +func (v *ValidateDecimalsSuite) TestMixtureDecimals() { + cfg := relconfig.Config{ + Chains: map[int]relconfig.ChainConfig{ + 1: { + Tokens: map[string]relconfig.TokenConfig{ + "USDC": { + Address: usdcAddr, + Decimals: 6, + }, + }, + }, + 42161: { + Tokens: map[string]relconfig.TokenConfig{ + "ARB": { + Address: arbAddr, + Decimals: 18, + }, + }, + }, + 10: { + Tokens: map[string]relconfig.TokenConfig{ + "OP": { + Address: opAddr, + Decimals: 69, + }, + }, + }, + }, + } + + err := cfg.Validate(v.GetTestContext(), v.omniClient) + v.Require().Error(err) +} diff --git a/services/rfq/relayer/relconfig/suite_test.go b/services/rfq/relayer/relconfig/suite_test.go new file mode 100644 index 0000000000..f6c66f7902 --- /dev/null +++ b/services/rfq/relayer/relconfig/suite_test.go @@ -0,0 +1,50 @@ +package relconfig_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + "github.com/synapsecns/sanguine/core" + "github.com/synapsecns/sanguine/core/metrics" + "github.com/synapsecns/sanguine/core/metrics/localmetrics" + "github.com/synapsecns/sanguine/core/testsuite" + omniClient "github.com/synapsecns/sanguine/services/omnirpc/client" + "github.com/synapsecns/sanguine/services/rfq/relayer/metadata" +) + +func TestValidateDecimalsSuite(t *testing.T) { + suite.Run(t, NewTestSuite(t)) +} + +type ValidateDecimalsSuite struct { + *testsuite.TestSuite + // testBackends contains a list of all test backends + metricsHandler metrics.Handler + omniClient omniClient.RPCClient +} + +// NewTestSuite creates a new test suite. +func NewTestSuite(tb testing.TB) *ValidateDecimalsSuite { + tb.Helper() + return &ValidateDecimalsSuite{ + TestSuite: testsuite.NewTestSuite(tb), + } +} + +func (v *ValidateDecimalsSuite) SetupSuite() { + v.TestSuite.SetupSuite() + + var err error + // don't use metrics on ci for integration tests + isCI := core.GetEnvBool("CI", false) + metricsHandler := metrics.Null + + if !isCI { + localmetrics.SetupTestJaeger(v.GetSuiteContext(), v.T()) + metricsHandler = metrics.Jaeger + } + v.metricsHandler, err = metrics.NewByType(v.GetSuiteContext(), metadata.BuildInfo(), metricsHandler) + v.Require().NoError(err) + + v.omniClient = omniClient.NewOmnirpcClient("https://rpc.omnirpc.io", v.metricsHandler, omniClient.WithCaptureReqRes()) +}