Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pooling of big.Int instances in the EVM #124

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions helper/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,43 @@ func EncodeUint64ToBytes(value uint64) []byte {
func EncodeBytesToUint64(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}

// Generic object pool implementation intended to be used in single-threaded
// manner and avoid synchronization overhead. It could be probably additionally
// be improved by using circular buffer as oposed to stack.
type UnsafePool[T any] struct {
Stefan-Ethernal marked this conversation as resolved.
Show resolved Hide resolved
stack []T
}

// Creates new instance of UnsafePool. Depending on observed usage, pool size
// should be set on creation to avoid pool resizing
func NewUnsafePool[T any]() *UnsafePool[T] {
return &UnsafePool[T]{}
}

// Get retrieves an object from the unsafepool, or allocates a new one if the pool
// is empty. The allocation logic (i.e., creating a new object of type T) needs to
// be provided externally, as Go's type system does not allow calling constructors
// or functions specific to T without an interface.
func (f *UnsafePool[T]) Get(newFunc func() T) T {
n := len(f.stack)
if n == 0 {
// Allocate a new T instance using the provided newFunc if the stack is empty.
return newFunc()
}

obj := f.stack[n-1]
f.stack = f.stack[:n-1]

return obj
}

// Put returns an object to the pool and executes reset function if provided. Reset
// function is used to return the T instance to initial state.
func (f *UnsafePool[T]) Put(resetFunc func(T) T, obj T) {
if resetFunc != nil {
obj = resetFunc(obj)
}

f.stack = append(f.stack, obj)
}
53 changes: 53 additions & 0 deletions helper/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,56 @@ func Test_SafeAddUint64(t *testing.T) {
})
}
}

func TestNewUnsafePool(t *testing.T) {
pool := NewUnsafePool[int]()

require.NotNilf(t, pool, "NewUnsafePool returned nil")

require.Empty(t, "Expected empty pool, got %v", pool.stack)
}

func TestUnsafePoolGetWhenEmpty(t *testing.T) {
pool := NewUnsafePool[int]()
newInt := func() int {
return 1
}

obj := pool.Get(newInt)

require.Equal(t, 1, obj, "Expected 1 from newFunc, got %v", obj)
}

func TestUnsafePoolGetPut(t *testing.T) {
pool := NewUnsafePool[int]()
resetInt := func(i int) int {
return 0
}

// Initially put an object into the pool.
pool.Put(resetInt, 2)

// Retrieve the object, which should now be the reset value.
obj := pool.Get(func() int { return 3 })

// Expecting the original object, not the one from newFunc
require.Equal(t, 0, obj, "Expected 0 from the pool, got %v", obj)

// Test if Get correctly uses newFunc when pool is empty again.
obj = pool.Get(func() int { return 3 })

require.Equal(t, 3, obj, "Expected 3 from newFunc, got %v", obj)
}

func TestUnsafePoolPutWithReset(t *testing.T) {
pool := NewUnsafePool[int]()
resetInt := func(i int) int {
return 0
}

// Put an object into the pool with a reset function.
pool.Put(resetInt, 5)

// Directly check if the object was reset.
require.Equal(t, 0, pool.stack[0], "Expected object to be reset to 0, got %v", pool.stack[0])
}
52 changes: 41 additions & 11 deletions state/runtime/evm/instructions_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package evm

import (
"errors"
"math/big"
"reflect"
"testing"

"github.com/0xPolygon/polygon-edge/chain"
Expand Down Expand Up @@ -37,9 +39,9 @@ func testLogicalOperation(t *testing.T, f instruction, test OperandsLogical, s *
f(s)

if test.expectedResult {
assert.Equal(t, one, s.pop())
assert.Equal(t, one.Uint64(), s.pop().Uint64())
} else {
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
}
}

Expand All @@ -57,7 +59,7 @@ func testArithmeticOperation(t *testing.T, f instruction, test OperandsArithmeti

f(s)

assert.Equal(t, test.expectedResult, s.pop())
assert.Equal(t, test.expectedResult.Uint64(), s.pop().Uint64())
}

func TestAdd(t *testing.T) {
Expand Down Expand Up @@ -355,7 +357,7 @@ func TestPush0(t *testing.T) {
defer closeFn()

opPush0(s)
require.Equal(t, zero, s.pop())
require.Equal(t, zero.Uint64(), s.pop().Uint64())
})

t.Run("single push0 (EIP-3855 disabled)", func(t *testing.T) {
Expand Down Expand Up @@ -857,15 +859,15 @@ func TestCallValue(t *testing.T) {
s.msg.Value = value

opCallValue(s)
assert.Equal(t, value, s.pop())
assert.Equal(t, value.Uint64(), s.pop().Uint64())
})

t.Run("Msg Value nil", func(t *testing.T) {
s, cancelFn := getState(&chain.ForksInTime{})
defer cancelFn()

opCallValue(s)
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
})
}

Expand All @@ -879,7 +881,7 @@ func TestCallDataLoad(t *testing.T) {
s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}

opCallDataLoad(s)
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
})
t.Run("ZeroOffset", func(t *testing.T) {
s, cancelFn := getState(&chain.ForksInTime{})
Expand All @@ -890,7 +892,7 @@ func TestCallDataLoad(t *testing.T) {
s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}

opCallDataLoad(s)
assert.NotEqual(t, zero, s.pop())
assert.NotEqual(t, zero.Uint64(), s.pop().Uint64())
})
}

Expand Down Expand Up @@ -1013,7 +1015,7 @@ func TestExtCodeHash(t *testing.T) {
opExtCodeHash(s)

assert.Equal(t, s.gas, gasLeft)
assert.Equal(t, one, s.pop())
assert.Equal(t, one.Uint64(), s.pop().Uint64())
})

t.Run("NonIstanbul", func(t *testing.T) {
Expand All @@ -1032,7 +1034,7 @@ func TestExtCodeHash(t *testing.T) {

opExtCodeHash(s)
assert.Equal(t, gasLeft, s.gas)
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
})

t.Run("NoForks", func(t *testing.T) {
Expand Down Expand Up @@ -2288,11 +2290,39 @@ func Test_opReturnDataCopy(t *testing.T) {

opReturnDataCopy(state)

assert.Equal(t, test.resultState, state)
assert.True(t, CompareStates(test.resultState, state))
})
}
}

// Since the state is complex structure, here is the specialized comparison
// function that checks significant fields. This function should be updated
// to suite future needs.
func CompareStates(a *state, b *state) bool {
// Compare simple fields
if a.ip != b.ip || a.lastGasCost != b.lastGasCost || a.sp != b.sp || !errors.Is(a.err, b.err) || a.stop != b.stop || a.gas != b.gas {
return false
}

// Deep compare slices
if !reflect.DeepEqual(a.code, b.code) || !reflect.DeepEqual(a.tmp, b.tmp) || !reflect.DeepEqual(a.returnData, b.returnData) || !reflect.DeepEqual(a.memory, b.memory) {
return false
}

// Deep comparison of stacks
if len(a.stack) != len(b.stack) {
return false
}

for i := range a.stack {
if a.stack[i].Cmp(b.stack[i]) != 0 {
return false
}
}

return true
}

func Test_opCall(t *testing.T) {
t.Parallel()

Expand Down
23 changes: 16 additions & 7 deletions state/runtime/evm/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ type state struct {

returnData []byte
ret []byte

unsafepool common.UnsafePool[*big.Int]
}

func (c *state) reset() {
Expand All @@ -95,6 +97,13 @@ func (c *state) reset() {
c.memory[i] = 0
}

// Before stack cleanup, return instances of big.Int to the pool
// for the future usage
for i := range c.stack {
c.unsafepool.Put(func(x *big.Int) *big.Int {
return x.SetInt64(0)
}, c.stack[i])
}
c.stack = c.stack[:0]
c.tmp = c.tmp[:0]
c.ret = c.ret[:0]
Expand Down Expand Up @@ -136,7 +145,10 @@ func (c *state) push1() *big.Int {
return c.stack[c.sp-1]
}

v := big.NewInt(0)
v := c.unsafepool.Get(func() *big.Int {
return big.NewInt(0)
})

c.stack = append(c.stack, v)
c.sp++

Expand Down Expand Up @@ -180,10 +192,6 @@ func (c *state) pop() *big.Int {
o := c.stack[c.sp-1]
c.sp--

if o.Cmp(zero) == 0 {
return big.NewInt(0)
}

return o
}

Expand Down Expand Up @@ -261,8 +269,9 @@ func (c *state) Run() ([]byte, error) {
// execute the instruction
inst.inst(c)

c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)

if c.host.GetTracer() != nil {
c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)
}
// check if stack size exceeds the max size
if c.sp > stackSize {
c.exit(&runtime.StackOverflowError{StackLen: c.sp, Limit: stackSize})
Expand Down
Loading