Skip to content

Commit

Permalink
ELO-335 -- make tokenAddresses optional for reward claiming; only que…
Browse files Browse the repository at this point in the history
…ry tokens which have pending rewards. (#223)
  • Loading branch information
bdchatham authored Oct 12, 2024
1 parent 08c43fc commit 297b1af
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 8 deletions.
59 changes: 53 additions & 6 deletions pkg/rewards/claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
contractrewardscoordinator "github.com/Layr-Labs/eigenlayer-contracts/pkg/bindings/IRewardsCoordinator"

"github.com/Layr-Labs/eigenlayer-rewards-proofs/pkg/claimgen"
"github.com/Layr-Labs/eigenlayer-rewards-proofs/pkg/distribution"
"github.com/Layr-Labs/eigenlayer-rewards-proofs/pkg/proofDataFetcher/httpProofDataFetcher"

"github.com/Layr-Labs/eigensdk-go/chainio/clients/elcontracts"
Expand All @@ -31,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/ethclient"

"github.com/urfave/cli/v2"
"github.com/wk8/go-ordered-map/v2"
)

type elChainReader interface {
Expand Down Expand Up @@ -121,10 +123,15 @@ func Claim(cCtx *cli.Context, p utils.Prompter) error {
return eigenSdkUtils.WrapError("failed to fetch claim amounts for date", err)
}

claimableTokens, present := proofData.Distribution.GetTokensForEarner(config.EarnerAddress)
if !present {
return errors.New("no tokens claimable by earner")
}

cg := claimgen.NewClaimgen(proofData.Distribution)
accounts, claim, err := cg.GenerateClaimProofForEarner(
config.EarnerAddress,
config.TokenAddresses,
getTokensToClaim(claimableTokens, config.TokenAddresses),
rootIndex,
)
if err != nil {
Expand Down Expand Up @@ -296,6 +303,44 @@ func getClaimDistributionRoot(
return "", 0, errors.New("invalid claim timestamp")
}

func getTokensToClaim(
claimableTokens *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
tokenAddresses []gethcommon.Address,
) []gethcommon.Address {
if len(tokenAddresses) == 0 {
tokenAddresses = getAllClaimableTokenAddresses(claimableTokens)
} else {
tokenAddresses = filterClaimableTokenAddresses(claimableTokens, tokenAddresses)
}

return tokenAddresses
}

func getAllClaimableTokenAddresses(
addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
) []gethcommon.Address {
var addresses []gethcommon.Address
for pair := addressesMap.Oldest(); pair != nil; pair = pair.Next() {
addresses = append(addresses, pair.Key)
}

return addresses
}

func filterClaimableTokenAddresses(
addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
providedAddresses []gethcommon.Address,
) []gethcommon.Address {
var addresses []gethcommon.Address
for _, address := range providedAddresses {
if _, ok := addressesMap.Get(address); ok {
addresses = append(addresses, address)
}
}

return addresses
}

func convertClaimTokenLeaves(
claimTokenLeaves []contractrewardscoordinator.IRewardsCoordinatorTokenTreeMerkleLeaf,
) []rewardscoordinator.IRewardsCoordinatorTokenTreeMerkleLeaf {
Expand All @@ -307,7 +352,6 @@ func convertClaimTokenLeaves(
})
}
return tokenLeaves

}

func readAndValidateClaimConfig(cCtx *cli.Context, logger logging.Logger) (*ClaimConfig, error) {
Expand All @@ -319,7 +363,8 @@ func readAndValidateClaimConfig(cCtx *cli.Context, logger logging.Logger) (*Clai
outputType := cCtx.String(flags.OutputTypeFlag.Name)
broadcast := cCtx.Bool(flags.BroadcastFlag.Name)
tokenAddresses := cCtx.String(TokenAddressesFlag.Name)
tokenAddressArray := stringToAddressArray(strings.Split(tokenAddresses, ","))
splitTokenAddresses := strings.Split(tokenAddresses, ",")
validTokenAddresses := getValidHexAddresses(splitTokenAddresses)
rewardsCoordinatorAddress := cCtx.String(RewardsCoordinatorAddressFlag.Name)

var err error
Expand Down Expand Up @@ -395,7 +440,7 @@ func readAndValidateClaimConfig(cCtx *cli.Context, logger logging.Logger) (*Clai
Output: output,
OutputType: outputType,
Broadcast: broadcast,
TokenAddresses: tokenAddressArray,
TokenAddresses: validTokenAddresses,
RewardsCoordinatorAddress: gethcommon.HexToAddress(rewardsCoordinatorAddress),
ChainID: chainID,
ProofStoreBaseURL: proofStoreBaseURL,
Expand Down Expand Up @@ -428,10 +473,12 @@ func getEnvFromNetwork(network string) string {
}
}

func stringToAddressArray(addresses []string) []gethcommon.Address {
func getValidHexAddresses(addresses []string) []gethcommon.Address {
var addressArray []gethcommon.Address
for _, address := range addresses {
addressArray = append(addressArray, gethcommon.HexToAddress(address))
if gethcommon.IsHexAddress(address) && address != utils.ZeroAddress.String() {
addressArray = append(addressArray, gethcommon.HexToAddress(address))
}
}
return addressArray
}
107 changes: 107 additions & 0 deletions pkg/rewards/claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (

"github.com/Layr-Labs/eigenlayer-cli/pkg/internal/common/flags"
"github.com/Layr-Labs/eigenlayer-cli/pkg/internal/testutils"
"github.com/Layr-Labs/eigenlayer-cli/pkg/utils"

"github.com/Layr-Labs/eigenlayer-rewards-proofs/pkg/distribution"

rewardscoordinator "github.com/Layr-Labs/eigensdk-go/contracts/bindings/IRewardsCoordinator"
"github.com/Layr-Labs/eigensdk-go/logging"
Expand All @@ -21,6 +24,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/urfave/cli/v2"
"github.com/wk8/go-ordered-map/v2"
)

type fakeELReader struct {
Expand Down Expand Up @@ -114,6 +118,48 @@ func TestReadAndValidateConfig_NoRecipientProvided(t *testing.T) {
assert.Equal(t, common.HexToAddress(earnerAddress), config.RecipientAddress)
}

func TestReadAndValidateConfig_NoTokenAddressesProvided(t *testing.T) {
earnerAddress := testutils.GenerateRandomEthereumAddressString()
recipientAddress := testutils.GenerateRandomEthereumAddressString()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.String(flags.ETHRpcUrlFlag.Name, "rpc", "")
fs.String(EarnerAddressFlag.Name, earnerAddress, "")
fs.String(RecipientAddressFlag.Name, recipientAddress, "")
fs.String(RewardsCoordinatorAddressFlag.Name, "0x1234", "")
fs.String(TokenAddressesFlag.Name, "", "")
fs.String(ClaimTimestampFlag.Name, "latest", "")
fs.String(ProofStoreBaseURLFlag.Name, "dummy-url", "")
cliCtx := cli.NewContext(nil, fs, nil)

logger := logging.NewJsonSLogger(os.Stdout, &logging.SLoggerOptions{})

config, err := readAndValidateClaimConfig(cliCtx, logger)

assert.NoError(t, err)
assert.ElementsMatch(t, config.TokenAddresses, []common.Address{})
}

func TestReadAndValidateConfig_ZeroTokenAddressesProvided(t *testing.T) {
earnerAddress := testutils.GenerateRandomEthereumAddressString()
recipientAddress := testutils.GenerateRandomEthereumAddressString()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.String(flags.ETHRpcUrlFlag.Name, "rpc", "")
fs.String(EarnerAddressFlag.Name, earnerAddress, "")
fs.String(RecipientAddressFlag.Name, recipientAddress, "")
fs.String(RewardsCoordinatorAddressFlag.Name, "0x1234", "")
fs.String(TokenAddressesFlag.Name, utils.ZeroAddress.String(), "")
fs.String(ClaimTimestampFlag.Name, "latest", "")
fs.String(ProofStoreBaseURLFlag.Name, "dummy-url", "")
cliCtx := cli.NewContext(nil, fs, nil)

logger := logging.NewJsonSLogger(os.Stdout, &logging.SLoggerOptions{})

config, err := readAndValidateClaimConfig(cliCtx, logger)

assert.NoError(t, err)
assert.ElementsMatch(t, config.TokenAddresses, []common.Address{})
}

func TestReadAndValidateConfig_RecipientProvided(t *testing.T) {
earnerAddress := testutils.GenerateRandomEthereumAddressString()
recipientAddress := testutils.GenerateRandomEthereumAddressString()
Expand Down Expand Up @@ -222,3 +268,64 @@ func TestGetClaimDistributionRoot(t *testing.T) {
})
}
}

func TestGetTokensToClaim(t *testing.T) {
// Set up a mock claimableTokens map
claimableTokens := orderedmap.New[common.Address, *distribution.BigInt]()
addr1 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
addr2 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
addr3 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())

claimableTokens.Set(addr1, newBigInt(100))
claimableTokens.Set(addr2, newBigInt(200))

// Case 1: No token addresses provided, should return all addresses in claimableTokens
result := getTokensToClaim(claimableTokens, []common.Address{})
expected := []common.Address{addr1, addr2}
assert.ElementsMatch(t, result, expected)

// Case 2: Provided token addresses, should return only those present in claimableTokens
result = getTokensToClaim(claimableTokens, []common.Address{addr2, addr3})
expected = []common.Address{addr2}
assert.ElementsMatch(t, result, expected)
}

func TestGetTokenAddresses(t *testing.T) {
// Set up a mock addresses map
addressesMap := orderedmap.New[common.Address, *distribution.BigInt]()
addr1 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
addr2 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())

addressesMap.Set(addr1, newBigInt(100))
addressesMap.Set(addr2, newBigInt(200))

// Test that the function returns all addresses in the map
result := getAllClaimableTokenAddresses(addressesMap)
expected := []common.Address{addr1, addr2}
assert.ElementsMatch(t, result, expected)
}

func TestFilterClaimableTokenAddresses(t *testing.T) {
// Set up a mock addresses map
addressesMap := orderedmap.New[common.Address, *distribution.BigInt]()
addr1 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
addr2 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())

addressesMap.Set(addr1, newBigInt(100))
addressesMap.Set(addr2, newBigInt(200))

// Test filtering with provided addresses
newMissingAddress := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
providedAddresses := []common.Address{
addr1,
newMissingAddress,
}

result := filterClaimableTokenAddresses(addressesMap, providedAddresses)
expected := []common.Address{addr1}
assert.ElementsMatch(t, result, expected)
}

func newBigInt(value int64) *distribution.BigInt {
return &distribution.BigInt{Int: big.NewInt(value)}
}
4 changes: 2 additions & 2 deletions pkg/rewards/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ var (
TokenAddressesFlag = cli.StringFlag{
Name: "token-addresses",
Aliases: []string{"t"},
Usage: "Specify the addresses of the tokens to claim. Comma separated list of addresses",
Usage: "Specify the addresses of the tokens to claim. Comma separated list of addresses. Omit to claim all rewards.",
EnvVars: []string{"TOKEN_ADDRESSES"},
Required: true,
Required: false,
}

RewardsCoordinatorAddressFlag = cli.StringFlag{
Expand Down

0 comments on commit 297b1af

Please sign in to comment.