Skip to content

Commit

Permalink
Add backward compatibility for EIP-1898 getBalance method
Browse files Browse the repository at this point in the history
Signed-off-by: Luca Georges Francois <luca.georges-francois@epitech.eu>
  • Loading branch information
0xpanoramix committed Dec 10, 2021
1 parent 53af097 commit b2f6abb
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 23 deletions.
57 changes: 53 additions & 4 deletions jsonrpc/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jsonrpc
import (
"encoding/json"
"fmt"
"strconv"
"strings"

"github.com/0xPolygon/polygon-sdk/types"
Expand Down Expand Up @@ -103,13 +104,61 @@ const (
type BlockNumber int64

type BlockNumberOrHash struct {
BlockNumber BlockNumber
BlockHash types.Hash
RequireCanonical bool
BlockNumber *BlockNumber `json:"blockNumber,omitempty"`
BlockHash *types.Hash `json:"blockHash,omitempty"`
}

func (h *BlockNumberOrHash) Extract() {
func (bnh *BlockNumberOrHash) Unmarshal(input *interface{}) error {
var placeholder BlockNumberOrHash

data, err := json.Marshal(*input)
if err != nil {
return fmt.Errorf("failed to serialize input: %v", err)
}

err = json.Unmarshal(data, &placeholder)
if err != nil {
var keyword string
err = json.Unmarshal(data, &keyword)
if err == nil {
// Try to extract keyword
switch keyword {
case "pending":
n := PendingBlockNumber
bnh.BlockNumber = &n
return nil
case "latest":
n := LatestBlockNumber
bnh.BlockNumber = &n
return nil
case "earliest":
n := EarliestBlockNumber
bnh.BlockNumber = &n
return nil
default:
// Try to extract hex number
s, ok := (*input).(string)
if !ok {
return fmt.Errorf("input cannot be converted to string")
}
number, err := strconv.ParseInt(s[2:], 16, 64)
if err != nil {
return fmt.Errorf("failed to convert hex string to int64: %v", err)
}
bnh.BlockNumber = (*BlockNumber)(&number)
return nil
}
}
}

// Try to extract object
bnh.BlockNumber = placeholder.BlockNumber
bnh.BlockHash = placeholder.BlockHash
return nil
}

func (bnh *BlockNumberOrHash) GetNumber() BlockNumber {
return *bnh.BlockNumber
}

func stringToBlockNumber(str string) (BlockNumber, error) {
Expand Down
33 changes: 16 additions & 17 deletions jsonrpc/eth_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,34 +550,33 @@ func (e *Eth) GetLogs(filterOptions *LogFilter) (interface{}, error) {
return result, nil
}

// GetBalance returns the account's balance at the referenced block
func (e *Eth) GetBalance(
address types.Address,
filter *BlockNumberOrHash,
) (interface{}, error) {
var number *BlockNumber
var blockHash string
// GetBalance returns the account's balance at the referenced block.
func (e *Eth) GetBalance(address types.Address, filter interface{}) (interface{}, error) {
var bnh BlockNumberOrHash
var header *types.Header
var err error

if filter == nil {
number, _ = createBlockNumberPointer("latest")
bnh.BlockNumber, _ = createBlockNumberPointer("latest")
} else {
number = &filter.BlockNumber
err = bnh.Unmarshal(&filter)
if err != nil {
return nil, fmt.Errorf("failed to decode filter: %v", err)
}
}

if filter != nil && filter.BlockHash != types.ZeroHash {
block, ok := e.d.store.GetBlockByHash(filter.BlockHash, false)
if bnh.BlockNumber != nil {
header, err = e.d.getBlockHeaderImpl(*bnh.BlockNumber)
if err != nil {
return nil, err
}
} else {
block, ok := e.d.store.GetBlockByHash(*bnh.BlockHash, false)
if !ok {
return nil, fmt.Errorf("could not find block referenced by the hash %s", blockHash)
return nil, fmt.Errorf("could not find block referenced by the hash %s", bnh.BlockHash.String())
}

header = block.Header
} else {
header, err = e.d.getBlockHeaderImpl(*number)
if err != nil {
return nil, err
}
}

acc, err := e.d.store.GetAccount(header.StateRoot, address)
Expand Down
6 changes: 4 additions & 2 deletions jsonrpc/eth_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,14 @@ func TestEth_State_GetBalance(t *testing.T) {
if err != nil {
assert.Error(t, err)
}
balance, err := dispatcher.endpoints.Eth.GetBalance(addr0, &BlockNumberOrHash{BlockNumber: *blockNumber})
param := BlockNumberOrHash{BlockNumber: blockNumber}
balance, err := dispatcher.endpoints.Eth.GetBalance(addr0, &param)
assert.NoError(t, err)
assert.Equal(t, balance, argBigPtr(big.NewInt(100)))

// address not found
balance, err = dispatcher.endpoints.Eth.GetBalance(addr1, &BlockNumberOrHash{BlockNumber: *blockNumber})
param = BlockNumberOrHash{BlockNumber: blockNumber}
balance, err = dispatcher.endpoints.Eth.GetBalance(addr1, &param)
assert.NoError(t, err)
assert.Equal(t, balance, argUintPtr(0))

Expand Down

0 comments on commit b2f6abb

Please sign in to comment.