Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: add inplace decimal operations #11004

Merged
merged 7 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 147 additions & 61 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ func (d Dec) GTE(d2 Dec) bool { return (d.i).Cmp(d2.i) >= 0 } // greater
func (d Dec) LT(d2 Dec) bool { return (d.i).Cmp(d2.i) < 0 } // less than
func (d Dec) LTE(d2 Dec) bool { return (d.i).Cmp(d2.i) <= 0 } // less than or equal
func (d Dec) Neg() Dec { return Dec{new(big.Int).Neg(d.i)} } // reverse the decimal sign
func (d Dec) NegMut() Dec { d.i.Neg(d.i); return d } // reverse the decimal sign, mutable
func (d Dec) Abs() Dec { return Dec{new(big.Int).Abs(d.i)} } // absolute value
func (d Dec) Set(d2 Dec) Dec { d.i.Set(d2.i); return d } // set to existing dec value
func (d Dec) Clone() Dec { return Dec{new(big.Int).Set(d.i)} } // clone new dec

// BigInt returns a copy of the underlying big.Int.
func (d Dec) BigInt() *big.Int {
Expand All @@ -220,123 +223,191 @@ func (d Dec) BigInt() *big.Int {
return cp.Set(d.i)
}

func (d Dec) ImmutOp(op func(Dec, Dec) Dec, d2 Dec) Dec {
return op(d.Clone(), d2)
}

func (d Dec) ImmutOpInt(op func(Dec, Int) Dec, d2 Int) Dec {
return op(d.Clone(), d2)
}

func (d Dec) ImmutOpInt64(op func(Dec, int64) Dec, d2 int64) Dec {
// TODO: use already allocated operand bigint to avoid
// newint each time, add mutex for race condition
return op(d.Clone(), d2)
}

func (d Dec) SetInt64(i int64) Dec {
d.i.SetInt64(i)
d.i.Mul(d.i, precisionReuse)
return d
}

// addition
func (d Dec) Add(d2 Dec) Dec {
res := new(big.Int).Add(d.i, d2.i)
return d.ImmutOp(Dec.AddMut, d2)
}

if res.BitLen() > maxDecBitLen {
// mutable addition
func (d Dec) AddMut(d2 Dec) Dec {
d.i.Add(d.i, d2.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{res}
return d
}

// subtraction
func (d Dec) Sub(d2 Dec) Dec {
res := new(big.Int).Sub(d.i, d2.i)
return d.ImmutOp(Dec.SubMut, d2)
}

if res.BitLen() > maxDecBitLen {
// mutable subtraction
func (d Dec) SubMut(d2 Dec) Dec {
d.i.Sub(d.i, d2.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{res}
return d
}

// multiplication
func (d Dec) Mul(d2 Dec) Dec {
mul := new(big.Int).Mul(d.i, d2.i)
chopped := chopPrecisionAndRound(mul)
return d.ImmutOp(Dec.MulMut, d2)
}

// mutable multiplication
func (d Dec) MulMut(d2 Dec) Dec {
d.i.Mul(d.i, d2.i)
chopped := chopPrecisionAndRound(d.i)

if chopped.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{chopped}
*d.i = *chopped
return d
}

// multiplication truncate
func (d Dec) MulTruncate(d2 Dec) Dec {
mul := new(big.Int).Mul(d.i, d2.i)
chopped := chopPrecisionAndTruncate(mul)
return d.ImmutOp(Dec.MulTruncateMut, d2)
}

if chopped.BitLen() > maxDecBitLen {
// mutable multiplication truncage
func (d Dec) MulTruncateMut(d2 Dec) Dec {
d.i.Mul(d.i, d2.i)
chopPrecisionAndTruncate(d.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// multiplication
func (d Dec) MulInt(i Int) Dec {
mul := new(big.Int).Mul(d.i, i.i)
return d.ImmutOpInt(Dec.MulIntMut, i)
}

if mul.BitLen() > maxDecBitLen {
func (d Dec) MulIntMut(i Int) Dec {
d.i.Mul(d.i, i.i)
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{mul}
return d
}

// MulInt64 - multiplication with int64
func (d Dec) MulInt64(i int64) Dec {
mul := new(big.Int).Mul(d.i, big.NewInt(i))
return d.ImmutOpInt64(Dec.MulInt64Mut, i)
}

func (d Dec) MulInt64Mut(i int64) Dec {
d.i.Mul(d.i, big.NewInt(i))

if mul.BitLen() > maxDecBitLen {
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{mul}
return d
}

// quotient
func (d Dec) Quo(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoMut, d2)
}

quo := new(big.Int).Quo(mul, d2.i)
chopped := chopPrecisionAndRound(quo)
// mutable quotient
func (d Dec) QuoMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > maxDecBitLen {
chopPrecisionAndRound(d.i)
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient truncate
func (d Dec) QuoTruncate(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoTruncateMut, d2)
}

quo := mul.Quo(mul, d2.i)
chopped := chopPrecisionAndTruncate(quo)
// mutable quotient truncate
func (d Dec) QuoTruncateMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > maxDecBitLen {
chopPrecisionAndTruncate(d.i)
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient, round up
func (d Dec) QuoRoundUp(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoRoundupMut, d2)
}

quo := new(big.Int).Quo(mul, d2.i)
chopped := chopPrecisionAndRoundUp(quo)
// mutable quotient, round up
func (d Dec) QuoRoundupMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > maxDecBitLen {
chopPrecisionAndRoundUp(d.i)
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient
func (d Dec) QuoInt(i Int) Dec {
mul := new(big.Int).Quo(d.i, i.i)
return Dec{mul}
return d.ImmutOpInt(Dec.QuoIntMut, i)
}

func (d Dec) QuoIntMut(i Int) Dec {
d.i.Quo(d.i, i.i)
return d
}

// QuoInt64 - quotient with int64
func (d Dec) QuoInt64(i int64) Dec {
mul := new(big.Int).Quo(d.i, big.NewInt(i))
return Dec{mul}
return d.ImmutOpInt64(Dec.QuoInt64Mut, i)
}

func (d Dec) QuoInt64Mut(i int64) Dec {
d.i.Quo(d.i, big.NewInt(i))
return d
}

// ApproxRoot returns an approximate estimation of a Dec's positive real nth root
Expand All @@ -357,8 +428,8 @@ func (d Dec) ApproxRoot(root uint64) (guess Dec, err error) {
}()

if d.IsNegative() {
absRoot, err := d.MulInt64(-1).ApproxRoot(root)
return absRoot.MulInt64(-1), err
absRoot, err := d.Neg().ApproxRoot(root)
return absRoot.NegMut(), err
}

if root == 1 || d.IsZero() || d.Equal(OneDec()) {
Expand All @@ -369,40 +440,45 @@ func (d Dec) ApproxRoot(root uint64) (guess Dec, err error) {
return OneDec(), nil
}

rootInt := NewIntFromUint64(root)
guess, delta := OneDec(), OneDec()

for iter := 0; delta.Abs().GT(SmallestDec()) && iter < maxApproxRootIterations; iter++ {
prev := guess.Power(root - 1)
if prev.IsZero() {
prev = SmallestDec()
}
delta = d.Quo(prev)
delta = delta.Sub(guess)
delta = delta.QuoInt(rootInt)
delta.Set(d).QuoMut(prev)
delta.SubMut(guess)
delta.QuoInt64Mut(int64(root))

guess = guess.Add(delta)
guess.AddMut(delta)
}

return guess, nil
}

// Power returns a the result of raising to a positive integer power
func (d Dec) Power(power uint64) Dec {
res := Dec{new(big.Int).Set(d.i)}
return res.PowerMut(power)
}

func (d Dec) PowerMut(power uint64) Dec {
// TODO: use mutable functions here
if power == 0 {
return OneDec()
}
tmp := OneDec()

for i := power; i > 1; {
if i%2 != 0 {
tmp = tmp.Mul(d)
tmp.MulMut(d)
}
i /= 2
d = d.Mul(d)
d.MulMut(d)
}

return d.Mul(tmp)
return d.MulMut(tmp)
}

// ApproxSqrt is a wrapper around ApproxRoot for the common special case
Expand Down Expand Up @@ -543,7 +619,7 @@ func chopPrecisionAndRoundUp(d *big.Int) *big.Int {
// make d positive, compute chopped value, and then un-mutate d
d = d.Neg(d)
// truncate since d is negative...
d = chopPrecisionAndTruncate(d)
chopPrecisionAndTruncate(d)
d = d.Neg(d)
return d
}
Expand Down Expand Up @@ -580,13 +656,19 @@ func (d Dec) RoundInt() Int {

// chopPrecisionAndTruncate is similar to chopPrecisionAndRound,
// but always rounds down. It does not mutate the input.
func chopPrecisionAndTruncate(d *big.Int) *big.Int {
return new(big.Int).Quo(d, precisionReuse)
func chopPrecisionAndTruncate(d *big.Int) {
d.Quo(d, precisionReuse)
}

func chopPrecisionAndTruncateNonMutative(d *big.Int) *big.Int {
tmp := new(big.Int).Set(d)
chopPrecisionAndTruncate(tmp)
return tmp
}

// TruncateInt64 truncates the decimals from the number and returns an int64
func (d Dec) TruncateInt64() int64 {
chopped := chopPrecisionAndTruncate(d.i)
chopped := chopPrecisionAndTruncateNonMutative(d.i)
if !chopped.IsInt64() {
panic("Int64() out of bound")
}
Expand All @@ -595,12 +677,12 @@ func (d Dec) TruncateInt64() int64 {

// TruncateInt truncates the decimals from the number and returns an Int
func (d Dec) TruncateInt() Int {
return NewIntFromBigInt(chopPrecisionAndTruncate(d.i))
return NewIntFromBigInt(chopPrecisionAndTruncateNonMutative(d.i))
}

// TruncateDec truncates the decimals from the number and returns a Dec
func (d Dec) TruncateDec() Dec {
return NewDecFromBigInt(chopPrecisionAndTruncate(d.i))
return NewDecFromBigInt(chopPrecisionAndTruncateNonMutative(d.i))
}

// Ceil returns the smallest interger value (as a decimal) that is greater than
Expand All @@ -625,7 +707,11 @@ func (d Dec) Ceil() Dec {

// MaxSortableDec is the largest Dec that can be passed into SortableDecBytes()
// Its negative form is the least Dec that can be passed in.
var MaxSortableDec = OneDec().Quo(SmallestDec())
var MaxSortableDec Dec

func init() {
MaxSortableDec = OneDec().Quo(SmallestDec())
}
Comment on lines +711 to +715
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not-blocking, but why didn't we keep the old one-liner?


// ValidSortableDec ensures that a Dec is within the sortable bounds,
// a Dec can't have a precision of less than 10^-18.
Expand Down
3 changes: 3 additions & 0 deletions types/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ func (s *decimalTestSuite) TestPower() {
for i, tc := range testCases {
res := tc.input.Power(tc.power)
s.Require().True(tc.expected.Sub(res).Abs().LTE(sdk.SmallestDec()), "unexpected result for test case %d, input: %v", i, tc.input)
s.Require().True(tc.expected.Sub(tc.input.PowerMut(tc.power)).Abs().LTE(sdk.SmallestDec()),
"unexpected result for test case %d, input %v", i, tc.input)
s.Require().True(res.Equal(tc.input), "unexpected result for test case %d, input: %v", i, tc.input)
}
}

Expand Down