diff --git a/CHANGELOG.md b/CHANGELOG.md index d3a551910..2b03feb20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * [#1223](https://github.com/NibiruChain/nibiru/pull/1223) - chore(deps): bump github.com/golang/protobuf from 1.5.2 to 1.5.3 * [#1205](https://github.com/NibiruChain/nibiru/pull/1205) - test: first testing framework skeleton and example * [#1228](https://github.com/NibiruChain/nibiru/pull/1228) - feat: update github.com/CosmWasm/wasmd 0.29.2 +* [#1237](https://github.com/NibiruChain/nibiru/pull/1237) - feat: reduce gas on openposition ### Bug Fixes diff --git a/x/common/testutil/mock/perp_interfaces.go b/x/common/testutil/mock/perp_interfaces.go index 25b0d7f02..528260eba 100644 --- a/x/common/testutil/mock/perp_interfaces.go +++ b/x/common/testutil/mock/perp_interfaces.go @@ -305,18 +305,18 @@ func (mr *MockVpoolKeeperMockRecorder) GetAllPools(arg0 interface{}) *gomock.Cal } // GetBaseAssetPrice mocks base method. -func (m *MockVpoolKeeper) GetBaseAssetPrice(arg0 types1.Context, arg1 asset.Pair, arg2 types0.Direction, arg3 types1.Dec) (types1.Dec, error) { +func (m *MockVpoolKeeper) GetBaseAssetPrice(arg0 types0.Vpool, arg1 types0.Direction, arg2 types1.Dec) (types1.Dec, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBaseAssetPrice", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetBaseAssetPrice", arg0, arg1, arg2) ret0, _ := ret[0].(types1.Dec) ret1, _ := ret[1].(error) return ret0, ret1 } // GetBaseAssetPrice indicates an expected call of GetBaseAssetPrice. -func (mr *MockVpoolKeeperMockRecorder) GetBaseAssetPrice(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockVpoolKeeperMockRecorder) GetBaseAssetPrice(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBaseAssetPrice", reflect.TypeOf((*MockVpoolKeeper)(nil).GetBaseAssetPrice), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBaseAssetPrice", reflect.TypeOf((*MockVpoolKeeper)(nil).GetBaseAssetPrice), arg0, arg1, arg2) } // GetBaseAssetTWAP mocks base method. @@ -394,19 +394,19 @@ func (mr *MockVpoolKeeperMockRecorder) GetMarkPriceTWAP(arg0, arg1, arg2 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMarkPriceTWAP", reflect.TypeOf((*MockVpoolKeeper)(nil).GetMarkPriceTWAP), arg0, arg1, arg2) } -// GetMaxLeverage mocks base method. -func (m *MockVpoolKeeper) GetMaxLeverage(arg0 types1.Context, arg1 asset.Pair) (types1.Dec, error) { +// GetPool mocks base method. +func (m *MockVpoolKeeper) GetPool(arg0 types1.Context, arg1 asset.Pair) (types0.Vpool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMaxLeverage", arg0, arg1) - ret0, _ := ret[0].(types1.Dec) + ret := m.ctrl.Call(m, "GetPool", arg0, arg1) + ret0, _ := ret[0].(types0.Vpool) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetMaxLeverage indicates an expected call of GetMaxLeverage. -func (mr *MockVpoolKeeperMockRecorder) GetMaxLeverage(arg0, arg1 interface{}) *gomock.Call { +// GetPool indicates an expected call of GetPool. +func (mr *MockVpoolKeeperMockRecorder) GetPool(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxLeverage", reflect.TypeOf((*MockVpoolKeeper)(nil).GetMaxLeverage), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPool", reflect.TypeOf((*MockVpoolKeeper)(nil).GetPool), arg0, arg1) } // GetQuoteAssetPrice mocks base method. @@ -455,12 +455,13 @@ func (mr *MockVpoolKeeperMockRecorder) IsOverSpreadLimit(arg0, arg1 interface{}) } // SwapBaseForQuote mocks base method. -func (m *MockVpoolKeeper) SwapBaseForQuote(arg0 types1.Context, arg1 asset.Pair, arg2 types0.Direction, arg3, arg4 types1.Dec, arg5 bool) (types1.Dec, error) { +func (m *MockVpoolKeeper) SwapBaseForQuote(arg0 types1.Context, arg1 types0.Vpool, arg2 types0.Direction, arg3, arg4 types1.Dec, arg5 bool) (types0.Vpool, types1.Dec, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SwapBaseForQuote", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].(types1.Dec) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(types0.Vpool) + ret1, _ := ret[1].(types1.Dec) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // SwapBaseForQuote indicates an expected call of SwapBaseForQuote. @@ -470,12 +471,13 @@ func (mr *MockVpoolKeeperMockRecorder) SwapBaseForQuote(arg0, arg1, arg2, arg3, } // SwapQuoteForBase mocks base method. -func (m *MockVpoolKeeper) SwapQuoteForBase(arg0 types1.Context, arg1 asset.Pair, arg2 types0.Direction, arg3, arg4 types1.Dec, arg5 bool) (types1.Dec, error) { +func (m *MockVpoolKeeper) SwapQuoteForBase(arg0 types1.Context, arg1 types0.Vpool, arg2 types0.Direction, arg3, arg4 types1.Dec, arg5 bool) (types0.Vpool, types1.Dec, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SwapQuoteForBase", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].(types1.Dec) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(types0.Vpool) + ret1, _ := ret[1].(types1.Dec) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // SwapQuoteForBase indicates an expected call of SwapQuoteForBase. diff --git a/x/perp/keeper/calc.go b/x/perp/keeper/calc.go index b63897773..591e25d29 100644 --- a/x/perp/keeper/calc.go +++ b/x/perp/keeper/calc.go @@ -3,6 +3,8 @@ package keeper import ( "fmt" + vpooltypes "github.com/NibiruChain/nibiru/x/vpool/types" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/NibiruChain/nibiru/x/common/asset" @@ -86,19 +88,16 @@ position without making it go underwater. - err: error */ func (k Keeper) calcFreeCollateral( - ctx sdk.Context, pos types.Position, + ctx sdk.Context, vpool vpooltypes.Vpool, pos types.Position, ) (freeCollateral sdk.Dec, err error) { if err = pos.Pair.Validate(); err != nil { return } - if err = k.requireVpool(ctx, pos.Pair); err != nil { - return - } - positionNotional, unrealizedPnL, err := k. GetPreferencePositionNotionalAndUnrealizedPnL( ctx, + vpool, pos, types.PnLPreferenceOption_MIN, ) diff --git a/x/perp/keeper/calc_unit_test.go b/x/perp/keeper/calc_unit_test.go index 05fd6ef97..8d72c69a0 100644 --- a/x/perp/keeper/calc_unit_test.go +++ b/x/perp/keeper/calc_unit_test.go @@ -15,67 +15,6 @@ import ( vpooltypes "github.com/NibiruChain/nibiru/x/vpool/types" ) -func TestCalcFreeCollateralErrors(t *testing.T) { - testCases := []struct { - name string - test func() - }{ - { - name: "invalid token pair - error", - test: func() { - k, _, ctx := getKeeper(t) - alice := testutil.AccAddress() - pos := types.ZeroPosition(ctx, asset.Pair("foobar"), alice) - _, err := k.calcFreeCollateral(ctx, pos) - - require.Error(t, err) - require.ErrorIs(t, err, asset.ErrInvalidTokenPair) - }, - }, - { - name: "token pair not found - error", - test: func() { - k, mocks, ctx := getKeeper(t) - - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(false) - - pos := types.ZeroPosition(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD), testutil.AccAddress()) - - _, err := k.calcFreeCollateral(ctx, pos) - - require.Error(t, err) - require.ErrorIs(t, err, types.ErrPairNotFound) - }, - }, - { - name: "zero position", - test: func() { - k, mocks, ctx := getKeeper(t) - - mocks.mockVpoolKeeper.EXPECT(). - ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true) - mocks.mockVpoolKeeper.EXPECT(). - GetMaintenanceMarginRatio(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). - Return(sdk.MustNewDecFromStr("0.0625"), nil) - - pos := types.ZeroPosition(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD), testutil.AccAddress()) - - freeCollateral, err := k.calcFreeCollateral(ctx, pos) - - require.NoError(t, err) - assert.EqualValues(t, sdk.ZeroDec(), freeCollateral) - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - tc.test() - }) - } -} - func TestCalcFreeCollateralSuccess(t *testing.T) { testCases := []struct { name string @@ -148,6 +87,7 @@ func TestCalcFreeCollateralSuccess(t *testing.T) { t.Run(tc.name, func(t *testing.T) { k, mocks, ctx := getKeeper(t) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} pos := types.Position{ TraderAddress: testutil.AccAddress().String(), Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), @@ -159,13 +99,11 @@ func TestCalcFreeCollateralSuccess(t *testing.T) { } t.Log("mock vpool keeper") - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true) mocks.mockVpoolKeeper.EXPECT(). GetMaintenanceMarginRatio(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). Return(sdk.MustNewDecFromStr("0.0625"), nil) mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.vpoolDirection, sdk.OneDec(), ).Return(tc.positionNotional, nil) @@ -177,7 +115,7 @@ func TestCalcFreeCollateralSuccess(t *testing.T) { 15*time.Minute, ).Return(tc.positionNotional, nil) - freeCollateral, err := k.calcFreeCollateral(ctx, pos) + freeCollateral, err := k.calcFreeCollateral(ctx, vpool, pos) require.NoError(t, err) assert.EqualValues(t, tc.expectedFreeCollateral, freeCollateral) diff --git a/x/perp/keeper/clearing_house.go b/x/perp/keeper/clearing_house.go index 484722345..011a3299f 100644 --- a/x/perp/keeper/clearing_house.go +++ b/x/perp/keeper/clearing_house.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/NibiruChain/collections" - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/NibiruChain/nibiru/x/common/asset" @@ -38,30 +37,33 @@ func (k Keeper) OpenPosition( leverage sdk.Dec, baseAmtLimit sdk.Dec, ) (positionResp *types.PositionResp, err error) { - err = k.checkOpenPositionRequirements(ctx, pair, quoteAssetAmount, leverage) + vpool, err := k.VpoolKeeper.GetPool(ctx, pair) if err != nil { - return nil, err + return nil, types.ErrPairNotFound } - // require params - params := k.GetParams(ctx) + err = k.checkOpenPositionRequirements(vpool, quoteAssetAmount, leverage) + if err != nil { + return nil, err + } position, err := k.Positions.Get(ctx, collections.Join(pair, traderAddr)) isNewPosition := errors.Is(err, collections.ErrNotFound) if isNewPosition { position = types.ZeroPosition(ctx, pair, traderAddr) - k.Positions.Insert(ctx, collections.Join(pair, traderAddr), position) } else if err != nil && !isNewPosition { return nil, err } sameSideLong := position.Size_.IsPositive() && side == types.Side_BUY sameSideShort := position.Size_.IsNegative() && side == types.Side_SELL + + var updatedVpool vpooltypes.Vpool var openSideMatchesPosition = sameSideLong || sameSideShort if isNewPosition || openSideMatchesPosition { - // increase position case - positionResp, err = k.increasePosition( + updatedVpool, positionResp, err = k.increasePosition( ctx, + vpool, position, side, /* openNotional */ leverage.MulInt(quoteAssetAmount), @@ -71,9 +73,9 @@ func (k Keeper) OpenPosition( return nil, err } } else { - // everything else decreases the position - positionResp, err = k.openReversePosition( + updatedVpool, positionResp, err = k.openReversePosition( ctx, + vpool, position, /* quoteAssetAmount */ quoteAssetAmount.ToDec(), /* leverage */ leverage, @@ -84,7 +86,7 @@ func (k Keeper) OpenPosition( } } - if err = k.afterPositionUpdate(ctx, pair, traderAddr, params, isNewPosition, *positionResp); err != nil { + if err = k.afterPositionUpdate(ctx, updatedVpool, traderAddr, *positionResp); err != nil { return nil, err } @@ -97,16 +99,7 @@ func (k Keeper) OpenPosition( // - Checks that quote asset is not zero. // - Checks that leverage is not zero. // - Checks that leverage is below requirement. -func (k Keeper) checkOpenPositionRequirements( - ctx sdk.Context, - pair asset.Pair, - quoteAssetAmount sdk.Int, - leverage sdk.Dec, -) error { - if err := k.requireVpool(ctx, pair); err != nil { - return err - } - +func (k Keeper) checkOpenPositionRequirements(vpool vpooltypes.Vpool, quoteAssetAmount sdk.Int, leverage sdk.Dec) error { if quoteAssetAmount.IsZero() { return types.ErrQuoteAmountIsZero } @@ -115,11 +108,7 @@ func (k Keeper) checkOpenPositionRequirements( return types.ErrLeverageIsZero } - maxLeverage, err := k.VpoolKeeper.GetMaxLeverage(ctx, pair) - if err != nil { - return err - } - if leverage.GT(maxLeverage) { + if leverage.GT(vpool.Config.MaxLeverage) { return types.ErrLeverageIsTooHigh } @@ -129,13 +118,11 @@ func (k Keeper) checkOpenPositionRequirements( // afterPositionUpdate is called when a position has been updated. func (k Keeper) afterPositionUpdate( ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, traderAddr sdk.AccAddress, - params types.Params, - isNewPosition bool, positionResp types.PositionResp, ) (err error) { - // update position in state + pair := vpool.Pair if !positionResp.Position.Size_.IsZero() { k.Positions.Insert(ctx, collections.Join(pair, traderAddr), *positionResp.Position) } @@ -147,6 +134,7 @@ func (k Keeper) afterPositionUpdate( if !positionResp.Position.Size_.IsZero() { marginRatio, err := k.GetMarginRatio( ctx, + vpool, *positionResp.Position, types.MarginCalculationPriceOption_MAX_PNL, ) @@ -154,10 +142,7 @@ func (k Keeper) afterPositionUpdate( return err } - maintenanceMarginRatio, err := k.VpoolKeeper.GetMaintenanceMarginRatio(ctx, pair) - if err != nil { - return err - } + maintenanceMarginRatio := vpool.Config.MaintenanceMarginRatio if err = validateMarginRatio(marginRatio, maintenanceMarginRatio, true); err != nil { return types.ErrMarginRatioTooLow } @@ -191,7 +176,7 @@ func (k Keeper) afterPositionUpdate( // calculate positionNotional (it's different depends on long or short side) // long: unrealizedPnl = positionNotional - openNotional => positionNotional = openNotional + unrealizedPnl // short: unrealizedPnl = openNotional - positionNotional => positionNotional = openNotional - unrealizedPnl - var positionNotional sdk.Dec = sdk.ZeroDec() + positionNotional := sdk.ZeroDec() if positionResp.Position.Size_.IsPositive() { positionNotional = positionResp.Position.OpenNotional.Add(positionResp.UnrealizedPnlAfter) } else if positionResp.Position.Size_.IsNegative() { @@ -246,24 +231,25 @@ ret: */ func (k Keeper) increasePosition( ctx sdk.Context, + vpool vpooltypes.Vpool, currentPosition types.Position, side types.Side, increasedNotional sdk.Dec, baseAmtLimit sdk.Dec, leverage sdk.Dec, -) (positionResp *types.PositionResp, err error) { +) (updatedVpool vpooltypes.Vpool, positionResp *types.PositionResp, err error) { positionResp = &types.PositionResp{} - positionResp.ExchangedPositionSize, err = k.swapQuoteForBase( + updatedVpool, positionResp.ExchangedPositionSize, err = k.swapQuoteForBase( ctx, - currentPosition.Pair, + vpool, side, increasedNotional, baseAmtLimit, /* skipFluctuationLimitCheck */ false, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } increaseMarginRequirement := increasedNotional.Quo(leverage) @@ -274,16 +260,17 @@ func (k Keeper) increasePosition( increaseMarginRequirement, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionNotional, unrealizedPnL, err := k.getPositionNotionalAndUnrealizedPnL( ctx, + updatedVpool, currentPosition, types.PnLCalcOption_SPOT_PRICE, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.ExchangedNotionalValue = increasedNotional @@ -303,31 +290,34 @@ func (k Keeper) increasePosition( BlockNumber: ctx.BlockHeight(), } - return positionResp, nil + return updatedVpool, positionResp, nil } // TODO test: openReversePosition | https://github.com/NibiruChain/nibiru/issues/299 func (k Keeper) openReversePosition( ctx sdk.Context, + vpool vpooltypes.Vpool, currentPosition types.Position, quoteAssetAmount sdk.Dec, leverage sdk.Dec, baseAmtLimit sdk.Dec, -) (positionResp *types.PositionResp, err error) { +) (updatedVpool vpooltypes.Vpool, positionResp *types.PositionResp, err error) { notionalToDecreaseBy := leverage.Mul(quoteAssetAmount) currentPositionNotional, _, err := k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, currentPosition, types.PnLCalcOption_SPOT_PRICE, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } if currentPositionNotional.GT(notionalToDecreaseBy) { // position reduction return k.decreasePosition( ctx, + vpool, currentPosition, notionalToDecreaseBy, baseAmtLimit, @@ -337,6 +327,7 @@ func (k Keeper) openReversePosition( // close and reverse return k.closeAndOpenReversePosition( ctx, + vpool, currentPosition, quoteAssetAmount, leverage, @@ -368,13 +359,14 @@ ret: */ func (k Keeper) decreasePosition( ctx sdk.Context, + vpool vpooltypes.Vpool, currentPosition types.Position, decreasedNotional sdk.Dec, baseAmtLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (positionResp *types.PositionResp, err error) { +) (updatedVpool vpooltypes.Vpool, positionResp *types.PositionResp, err error) { if currentPosition.Size_.IsZero() { - return nil, fmt.Errorf("current position size is zero, nothing to decrease") + return vpooltypes.Vpool{}, nil, fmt.Errorf("current position size is zero, nothing to decrease") } positionResp = &types.PositionResp{ @@ -384,11 +376,12 @@ func (k Keeper) decreasePosition( currentPositionNotional, currentUnrealizedPnL, err := k. getPositionNotionalAndUnrealizedPnL( ctx, + vpool, currentPosition, types.PnLCalcOption_SPOT_PRICE, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } var sideToTake types.Side @@ -399,16 +392,16 @@ func (k Keeper) decreasePosition( sideToTake = types.Side_BUY } - positionResp.ExchangedPositionSize, err = k.swapQuoteForBase( + updatedVpool, positionResp.ExchangedPositionSize, err = k.swapQuoteForBase( ctx, - currentPosition.Pair, + vpool, sideToTake, decreasedNotional, baseAmtLimit, skipFluctuationLimitCheck, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.RealizedPnl = currentUnrealizedPnL.Mul( @@ -422,7 +415,7 @@ func (k Keeper) decreasePosition( positionResp.RealizedPnl, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.BadDebt = remaining.BadDebt @@ -443,7 +436,7 @@ func (k Keeper) decreasePosition( } if remainOpenNotional.IsNegative() { - return nil, fmt.Errorf("value of open notional < 0") + return vpooltypes.Vpool{}, nil, fmt.Errorf("value of open notional < 0") } positionResp.Position = &types.Position{ @@ -456,7 +449,7 @@ func (k Keeper) decreasePosition( BlockNumber: ctx.BlockHeight(), } - return positionResp, nil + return updatedVpool, positionResp, nil } /* @@ -479,37 +472,40 @@ ret: */ func (k Keeper) closeAndOpenReversePosition( ctx sdk.Context, + vpool vpooltypes.Vpool, existingPosition types.Position, quoteAssetAmount sdk.Dec, leverage sdk.Dec, baseAmtLimit sdk.Dec, -) (positionResp *types.PositionResp, err error) { +) (updatedVpool vpooltypes.Vpool, positionResp *types.PositionResp, err error) { trader, err := sdk.AccAddressFromBech32(existingPosition.TraderAddress) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } - closePositionResp, err := k.closePositionEntirely( + updatedVpool, closePositionResp, err := k.closePositionEntirely( ctx, + vpool, existingPosition, /* quoteAssetAmountLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ false, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } if closePositionResp.BadDebt.IsPositive() { - return nil, fmt.Errorf("underwater position") + return vpooltypes.Vpool{}, nil, fmt.Errorf("underwater position") } reverseNotionalValue := leverage.Mul(quoteAssetAmount) remainingReverseNotionalValue := reverseNotionalValue.Sub( closePositionResp.ExchangedNotionalValue) + var increasePositionResp *types.PositionResp if remainingReverseNotionalValue.IsNegative() { // should never happen as this should also be checked in the caller - return nil, fmt.Errorf( + return vpooltypes.Vpool{}, nil, fmt.Errorf( "provided quote asset amount and leverage not large enough to close position. need %s but got %s", closePositionResp.ExchangedNotionalValue.String(), reverseNotionalValue.String()) } else if remainingReverseNotionalValue.IsPositive() { @@ -518,7 +514,7 @@ func (k Keeper) closeAndOpenReversePosition( updatedBaseAmtLimit = baseAmtLimit.Sub(closePositionResp.ExchangedPositionSize.Abs()) } if updatedBaseAmtLimit.IsNegative() { - return nil, fmt.Errorf( + return vpooltypes.Vpool{}, nil, fmt.Errorf( "position size changed by greater than the specified base limit: %s", baseAmtLimit.String(), ) @@ -537,8 +533,9 @@ func (k Keeper) closeAndOpenReversePosition( existingPosition.Pair, trader, ) - increasePositionResp, err := k.increasePosition( + updatedVpool, increasePositionResp, err = k.increasePosition( ctx, + updatedVpool, newPosition, sideToTake, remainingReverseNotionalValue, @@ -546,7 +543,7 @@ func (k Keeper) closeAndOpenReversePosition( leverage, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp = &types.PositionResp{ @@ -565,7 +562,7 @@ func (k Keeper) closeAndOpenReversePosition( positionResp = closePositionResp } - return positionResp, nil + return updatedVpool, positionResp, nil } /* @@ -584,17 +581,18 @@ ret: */ func (k Keeper) closePositionEntirely( ctx sdk.Context, + vpool vpooltypes.Vpool, currentPosition types.Position, quoteAssetAmountLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (positionResp *types.PositionResp, err error) { +) (updatedVpool vpooltypes.Vpool, positionResp *types.PositionResp, err error) { if currentPosition.Size_.IsZero() { - return nil, fmt.Errorf("zero position size") + return vpooltypes.Vpool{}, nil, fmt.Errorf("zero position size") } trader, err := sdk.AccAddressFromBech32(currentPosition.TraderAddress) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp = &types.PositionResp{ @@ -606,11 +604,12 @@ func (k Keeper) closePositionEntirely( // calculate unrealized PnL _, unrealizedPnL, err := k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, currentPosition, types.PnLCalcOption_SPOT_PRICE, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.RealizedPnl = unrealizedPnL @@ -618,7 +617,7 @@ func (k Keeper) closePositionEntirely( remaining, err := k.CalcRemainMarginWithFundingPayment( ctx, currentPosition, unrealizedPnL) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.BadDebt = remaining.BadDebt @@ -632,16 +631,16 @@ func (k Keeper) closePositionEntirely( } else { sideToTake = types.Side_BUY } - exchangedNotionalValue, err := k.swapBaseForQuote( + updatedVpool, exchangedNotionalValue, err := k.swapBaseForQuote( ctx, - currentPosition.Pair, + vpool, sideToTake, currentPosition.Size_.Abs(), quoteAssetAmountLimit, skipFluctuationLimitCheck, ) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } positionResp.ExchangedNotionalValue = exchangedNotionalValue @@ -657,10 +656,10 @@ func (k Keeper) closePositionEntirely( err = k.Positions.Delete(ctx, collections.Join(currentPosition.Pair, trader)) if err != nil { - return nil, err + return vpooltypes.Vpool{}, nil, err } - return positionResp, nil + return updatedVpool, positionResp, nil } /* @@ -682,8 +681,14 @@ func (k Keeper) ClosePosition(ctx sdk.Context, pair asset.Pair, traderAddr sdk.A return nil, err } - positionResp, err := k.closePositionEntirely( + vpool, err := k.VpoolKeeper.GetPool(ctx, pair) + if err != nil { + return nil, types.ErrPairNotFound + } + + updatedVpool, positionResp, err := k.closePositionEntirely( ctx, + vpool, position, /* quoteAssetAmountLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ false, @@ -698,10 +703,8 @@ func (k Keeper) ClosePosition(ctx sdk.Context, pair asset.Pair, traderAddr sdk.A if err = k.afterPositionUpdate( ctx, - pair, + updatedVpool, traderAddr, - k.GetParams(ctx), - /* isNewPosition */ false, *positionResp, ); err != nil { return nil, err @@ -773,12 +776,12 @@ ret: */ func (k Keeper) swapQuoteForBase( ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, side types.Side, quoteAssetAmount sdk.Dec, baseAssetLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (baseAssetAmount sdk.Dec, err error) { +) (updatedVpool vpooltypes.Vpool, baseAssetAmount sdk.Dec, err error) { var quoteAssetDirection vpooltypes.Direction if side == types.Side_BUY { quoteAssetDirection = vpooltypes.Direction_ADD_TO_POOL @@ -787,16 +790,16 @@ func (k Keeper) swapQuoteForBase( quoteAssetDirection = vpooltypes.Direction_REMOVE_FROM_POOL } - baseAssetAmount, err = k.VpoolKeeper.SwapQuoteForBase( - ctx, pair, quoteAssetDirection, quoteAssetAmount, baseAssetLimit, skipFluctuationLimitCheck) + updatedVpool, baseAssetAmount, err = k.VpoolKeeper.SwapQuoteForBase( + ctx, vpool, quoteAssetDirection, quoteAssetAmount, baseAssetLimit, skipFluctuationLimitCheck) if err != nil { - return sdk.Dec{}, err + return vpooltypes.Vpool{}, sdk.Dec{}, err } if side == types.Side_SELL { baseAssetAmount = baseAssetAmount.Neg() } - k.OnSwapEnd(ctx, pair, quoteAssetAmount, baseAssetAmount) - return baseAssetAmount, nil + k.OnSwapEnd(ctx, vpool.Pair, quoteAssetAmount, baseAssetAmount) + return updatedVpool, baseAssetAmount, nil } /* @@ -818,12 +821,12 @@ ret: */ func (k Keeper) swapBaseForQuote( ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, side types.Side, baseAssetAmount sdk.Dec, quoteAssetLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (baseAmount sdk.Dec, err error) { +) (updatedVpool vpooltypes.Vpool, baseAmount sdk.Dec, err error) { var baseAssetDirection vpooltypes.Direction if side == types.Side_SELL { baseAssetDirection = vpooltypes.Direction_ADD_TO_POOL @@ -831,16 +834,19 @@ func (k Keeper) swapBaseForQuote( // side == types.Side_BUY baseAssetDirection = vpooltypes.Direction_REMOVE_FROM_POOL } - quoteAssetAmount, err := k.VpoolKeeper.SwapBaseForQuote( - ctx, pair, baseAssetDirection, baseAssetAmount, quoteAssetLimit, skipFluctuationLimitCheck) + + updatedVpool, quoteAssetAmount, err := k.VpoolKeeper.SwapBaseForQuote( + ctx, vpool, baseAssetDirection, baseAssetAmount, quoteAssetLimit, skipFluctuationLimitCheck) if err != nil { - return sdk.Dec{}, err + return vpooltypes.Vpool{}, sdk.Dec{}, err } + if side == types.Side_SELL { baseAssetAmount = baseAssetAmount.Neg() } - k.OnSwapEnd(ctx, pair, quoteAssetAmount, baseAssetAmount) - return quoteAssetAmount, err + + k.OnSwapEnd(ctx, vpool.Pair, quoteAssetAmount, baseAssetAmount) + return updatedVpool, quoteAssetAmount, err } // OnSwapEnd recalculates perp metrics for a particular pair. diff --git a/x/perp/keeper/clearing_house_integration_test.go b/x/perp/keeper/clearing_house_integration_test.go index 2fdaf29e6..a656b5039 100644 --- a/x/perp/keeper/clearing_house_integration_test.go +++ b/x/perp/keeper/clearing_house_integration_test.go @@ -48,7 +48,6 @@ func TestOpenPosition(t *testing.T) { ts := NewTestSuite(t) alice := testutil.AccAddress() - pairBtcUsdc := asset.Registry.Pair(denoms.BTC, denoms.USDC) startBlockTime := time.Now() @@ -443,32 +442,32 @@ func TestOpenPositionSuccess(t *testing.T) { require.NoError(t, err) t.Log("assert position response") - assert.EqualValues(t, asset.Registry.Pair(denoms.BTC, denoms.NUSD), resp.Position.Pair) - assert.EqualValues(t, traderAddr.String(), resp.Position.TraderAddress) - assert.EqualValues(t, tc.expectedMargin, resp.Position.Margin, "margin") - assert.EqualValues(t, tc.expectedOpenNotional, resp.Position.OpenNotional, "open notional") - assert.EqualValues(t, tc.expectedSize, resp.Position.Size_, "position size") - assert.EqualValues(t, ctx.BlockHeight(), resp.Position.BlockNumber) - assert.EqualValues(t, sdk.ZeroDec(), resp.Position.LatestCumulativePremiumFraction) - assert.EqualValues(t, tc.leverage.MulInt(tc.margin), resp.ExchangedNotionalValue) - assert.EqualValues(t, exchangedSize, resp.ExchangedPositionSize) - assert.EqualValues(t, sdk.ZeroDec(), resp.BadDebt) - assert.EqualValues(t, sdk.ZeroDec(), resp.FundingPayment) - assert.EqualValues(t, tc.expectedRealizedPnl, resp.RealizedPnl) - assert.EqualValues(t, tc.expectedUnrealizedPnl, resp.UnrealizedPnlAfter) - assert.EqualValues(t, tc.expectedMarginToVault, resp.MarginToVault) - assert.EqualValues(t, tc.expectedPositionNotional, resp.PositionNotional) + require.EqualValues(t, asset.Registry.Pair(denoms.BTC, denoms.NUSD), resp.Position.Pair) + require.EqualValues(t, traderAddr.String(), resp.Position.TraderAddress) + require.EqualValues(t, tc.expectedMargin, resp.Position.Margin, "margin") + require.EqualValues(t, tc.expectedOpenNotional, resp.Position.OpenNotional, "open notional") + require.EqualValues(t, tc.expectedSize, resp.Position.Size_, "position size") + require.EqualValues(t, ctx.BlockHeight(), resp.Position.BlockNumber) + require.EqualValues(t, sdk.ZeroDec(), resp.Position.LatestCumulativePremiumFraction) + require.EqualValues(t, tc.leverage.MulInt(tc.margin), resp.ExchangedNotionalValue) + require.EqualValues(t, exchangedSize, resp.ExchangedPositionSize) + require.EqualValues(t, sdk.ZeroDec(), resp.BadDebt) + require.EqualValues(t, sdk.ZeroDec(), resp.FundingPayment) + require.EqualValues(t, tc.expectedRealizedPnl, resp.RealizedPnl) + require.EqualValues(t, tc.expectedUnrealizedPnl, resp.UnrealizedPnlAfter) + require.EqualValues(t, tc.expectedMarginToVault, resp.MarginToVault) + require.EqualValues(t, tc.expectedPositionNotional, resp.PositionNotional) t.Log("assert position in state") position, err := nibiruApp.PerpKeeper.Positions.Get(ctx, collections.Join(asset.Registry.Pair(denoms.BTC, denoms.NUSD), traderAddr)) require.NoError(t, err) - assert.EqualValues(t, asset.Registry.Pair(denoms.BTC, denoms.NUSD), position.Pair) - assert.EqualValues(t, traderAddr.String(), position.TraderAddress) - assert.EqualValues(t, tc.expectedMargin, position.Margin, "margin") - assert.EqualValues(t, tc.expectedOpenNotional, position.OpenNotional, "open notional") - assert.EqualValues(t, tc.expectedSize, position.Size_, "position size") - assert.EqualValues(t, ctx.BlockHeight(), position.BlockNumber) - assert.EqualValues(t, sdk.ZeroDec(), position.LatestCumulativePremiumFraction) + require.EqualValues(t, asset.Registry.Pair(denoms.BTC, denoms.NUSD), position.Pair) + require.EqualValues(t, traderAddr.String(), position.TraderAddress) + require.EqualValues(t, tc.expectedMargin, position.Margin, "margin") + require.EqualValues(t, tc.expectedOpenNotional, position.OpenNotional, "open notional") + require.EqualValues(t, tc.expectedSize, position.Size_, "position size") + require.EqualValues(t, ctx.BlockHeight(), position.BlockNumber) + require.EqualValues(t, sdk.ZeroDec(), position.LatestCumulativePremiumFraction) exchangedNotional := tc.leverage.MulInt(tc.margin) feePoolFee := nibiruApp.PerpKeeper.GetParams(ctx).FeePoolFeeRatio.Mul(exchangedNotional).RoundInt() diff --git a/x/perp/keeper/clearing_house_unit_test.go b/x/perp/keeper/clearing_house_unit_test.go index 2b1c527dc..ab9e830ce 100644 --- a/x/perp/keeper/clearing_house_unit_test.go +++ b/x/perp/keeper/clearing_house_unit_test.go @@ -112,15 +112,16 @@ func TestSwapQuoteAssetForBase(t *testing.T) { { name: "long position - buy", setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /*quoteAmount=*/ sdk.NewDec(10), /*baseLimit=*/ sdk.NewDec(1), /* skipFluctuationLimitCheck */ false, - ).Return(sdk.NewDec(5), nil) + ).Return(vpool, sdk.NewDec(5), nil) }, side: types.Side_BUY, expectedBaseAmount: sdk.NewDec(5), @@ -128,15 +129,16 @@ func TestSwapQuoteAssetForBase(t *testing.T) { { name: "short position - sell", setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /*quoteAmount=*/ sdk.NewDec(10), /*baseLimit=*/ sdk.NewDec(1), /* skipFluctuationLimitCheck */ false, - ).Return(sdk.NewDec(5), nil) + ).Return(vpool, sdk.NewDec(5), nil) }, side: types.Side_SELL, expectedBaseAmount: sdk.NewDec(-5), @@ -150,9 +152,11 @@ func TestSwapQuoteAssetForBase(t *testing.T) { tc.setMocks(ctx, mocks) - baseAmount, err := perpKeeper.swapQuoteForBase( + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + + _, baseAmount, err := perpKeeper.swapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.side, sdk.NewDec(10), sdk.NewDec(1), @@ -170,7 +174,7 @@ func TestIncreasePosition(t *testing.T) { name string initPosition types.Position given func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) - when func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) + when func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) then func(t *testing.T, ctx sdk.Context, initPosition types.Position, resp *types.PositionResp, err error) }{ { @@ -189,20 +193,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_ADD_TO_POOL, /*quoteAssetAmount=*/ sdk.NewDec(100), /*baseAssetLimit=*/ sdk.NewDec(50), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(50), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(50), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /*baseAssetAmount=*/ sdk.NewDec(100), ). @@ -216,10 +222,11 @@ func TestIncreasePosition(t *testing.T) { }, ) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_BUY, /*openNotional=*/ sdk.NewDec(100), // NUSD @@ -263,20 +270,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_ADD_TO_POOL, /*quoteAssetAmount=*/ sdk.NewDec(100), /*baseAssetLimit=*/ sdk.NewDec(101), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(101), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(101), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /*baseAssetAmount=*/ sdk.NewDec(100), ). @@ -288,10 +297,11 @@ func TestIncreasePosition(t *testing.T) { LatestCumulativePremiumFraction: sdk.MustNewDecFromStr("0.02"), }) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_BUY, /*openNotional=*/ sdk.NewDec(100), // NUSD @@ -338,20 +348,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_ADD_TO_POOL, /*quoteAssetAmount=*/ sdk.NewDec(100), /*baseAssetLimit=*/ sdk.NewDec(110), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(110), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(110), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /*baseAssetAmount=*/ sdk.NewDec(110), ). @@ -363,10 +375,11 @@ func TestIncreasePosition(t *testing.T) { LatestCumulativePremiumFraction: sdk.MustNewDecFromStr("0.2"), }) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_BUY, /*openNotional=*/ sdk.NewDec(100), // NUSD @@ -411,20 +424,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_REMOVE_FROM_POOL, /*quoteAssetAmount=*/ sdk.NewDec(100), /*baseAssetLimit=*/ sdk.NewDec(200), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(200), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(200), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /*baseAssetAmount=*/ sdk.NewDec(100), ). @@ -436,10 +451,11 @@ func TestIncreasePosition(t *testing.T) { LatestCumulativePremiumFraction: sdk.MustNewDecFromStr("0.02"), }) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_SELL, /*openNotional=*/ sdk.NewDec(100), // NUSD @@ -484,20 +500,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_REMOVE_FROM_POOL, /*quoteAssetAmount=*/ sdk.NewDec(100), /*baseAssetLimit=*/ sdk.NewDec(99), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(99), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(99), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /*baseAssetAmount=*/ sdk.NewDec(100), ). @@ -509,10 +527,11 @@ func TestIncreasePosition(t *testing.T) { LatestCumulativePremiumFraction: sdk.MustNewDecFromStr("0.02"), }) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_SELL, /*openNotional=*/ sdk.NewDec(100), // NUSD @@ -560,20 +579,22 @@ func TestIncreasePosition(t *testing.T) { }, given: func(ctx sdk.Context, mocks mockedDependencies, perpKeeper Keeper) { t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ vpooltypes.Direction_REMOVE_FROM_POOL, /*quoteAssetAmount=*/ sdk.NewDec(105), /*baseAssetLimit=*/ sdk.NewDec(100), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ sdk.NewDec(100), nil) + ).Return(vpool /*baseAssetAmount=*/, sdk.NewDec(100), nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /*baseAssetAmount=*/ sdk.NewDec(100), ). @@ -585,10 +606,11 @@ func TestIncreasePosition(t *testing.T) { LatestCumulativePremiumFraction: sdk.MustNewDecFromStr("-0.3"), }) }, - when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (*types.PositionResp, error) { + when: func(ctx sdk.Context, perpKeeper Keeper, initPosition types.Position) (vpooltypes.Vpool, *types.PositionResp, error) { t.Log("Increase position with 10.5 NUSD margin and 10x leverage.") return perpKeeper.increasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, initPosition, types.Side_SELL, /*openNotional=*/ sdk.NewDec(105), // NUSD @@ -625,7 +647,7 @@ func TestIncreasePosition(t *testing.T) { tc.given(ctx, mocks, perpKeeper) - resp, err := tc.when(ctx, perpKeeper, tc.initPosition) + _, resp, err := tc.when(ctx, perpKeeper, tc.initPosition) tc.then(t, ctx, tc.initPosition, resp, err) }) @@ -822,10 +844,12 @@ func TestClosePositionEntirely(t *testing.T) { SetPosition(perpKeeper, ctx, tc.initialPosition) t.Log("mock vpool") + vpool := vpooltypes.Vpool{ + Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), + } mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.direction, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), ). @@ -834,19 +858,21 @@ func TestClosePositionEntirely(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ tc.direction, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), /*quoteAssetLimit=*/ tc.quoteAssetLimit, /* skipFluctuationLimitCheck */ false, - ).Return( /*quoteAssetAmount=*/ tc.newPositionNotional, nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.newPositionNotional, nil) t.Log("set up pair metadata and last cumulative funding rate") SetPairMetadata(perpKeeper, ctx, tc.pairMetadata) t.Log("close position") - resp, err := perpKeeper.closePositionEntirely( + + _, resp, err := perpKeeper.closePositionEntirely( ctx, + vpool, tc.initialPosition, /*quoteAssetLimit=*/ tc.quoteAssetLimit, // NUSD /* skipFluctuationLimitCheck */ false, @@ -1118,10 +1144,10 @@ func TestDecreasePosition(t *testing.T) { perpKeeper, mocks, ctx := getKeeper(t) t.Log("mock vpool") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.baseAssetDir, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), ). @@ -1130,12 +1156,12 @@ func TestDecreasePosition(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ tc.quoteAssetDir, /*quoteAssetAmount=*/ tc.quoteAmountToDecrease, /*baseAssetLimit=*/ tc.exchangedBaseAmount.Abs(), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ tc.baseAssetLimit, nil) + ).Return(vpool /*baseAssetAmount=*/, tc.baseAssetLimit, nil) t.Log("set up pair metadata and last cumulative funding rate") SetPairMetadata(perpKeeper, ctx, types.PairMetadata{ @@ -1144,8 +1170,9 @@ func TestDecreasePosition(t *testing.T) { }) t.Log("decrease position") - resp, err := perpKeeper.decreasePosition( + _, resp, err := perpKeeper.decreasePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, tc.initialPosition, /*openNotional=*/ tc.quoteAmountToDecrease, // NUSD /*baseLimit=*/ tc.baseAssetLimit, // BTC @@ -1517,10 +1544,10 @@ func TestCloseAndOpenReversePosition(t *testing.T) { SetPosition(perpKeeper, ctx, currentPosition) t.Log("mock vpool") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.mockBaseDir, /*baseAssetAmount=*/ currentPosition.Size_.Abs(), ). @@ -1529,23 +1556,23 @@ func TestCloseAndOpenReversePosition(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.mockBaseDir, /*baseAssetAmount=*/ currentPosition.Size_.Abs(), /*quoteAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ false, - ).Return( /*quoteAssetAmount=*/ tc.mockQuoteAmount, nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.mockQuoteAmount, nil) if tc.expectedErr == nil { mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*quoteAssetDirection=*/ tc.mockQuoteDir, /*quoteAssetAmount=*/ tc.inputQuoteAmount.Mul(tc.inputLeverage).Sub(tc.mockQuoteAmount), /*baseAssetLimit=*/ sdk.MaxDec(tc.inputBaseAssetLimit.Sub(currentPosition.Size_.Abs()), sdk.ZeroDec()), /* skipFluctuationLimitCheck */ false, - ).Return( /*baseAssetAmount=*/ tc.mockBaseAmount, nil) + ).Return(vpool /*baseAssetAmount=*/, tc.mockBaseAmount, nil) } t.Log("set up pair metadata and last cumulative funding rate") @@ -1555,8 +1582,9 @@ func TestCloseAndOpenReversePosition(t *testing.T) { }) t.Log("close position and open reverse") - resp, err := perpKeeper.closeAndOpenReversePosition( + _, resp, err := perpKeeper.closeAndOpenReversePosition( ctx, + vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)}, currentPosition, /*quoteAssetAmount=*/ tc.inputQuoteAmount, // NUSD /*leverage=*/ tc.inputLeverage, @@ -1802,10 +1830,13 @@ func TestClosePosition(t *testing.T) { perpKeeper.SetParams(ctx, params) t.Log("mock vpool keeper") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT(). + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). + Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.baseAssetDir, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), ). @@ -1814,12 +1845,12 @@ func TestClosePosition(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*baseAssetDirection=*/ tc.baseAssetDir, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), /*quoteAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ false, - ).Return( /*quoteAssetAmount=*/ tc.newPositionNotional, nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.newPositionNotional, nil) mocks.mockVpoolKeeper.EXPECT(). GetMarkPrice( @@ -1965,10 +1996,14 @@ func TestClosePositionWithBadDebt(t *testing.T) { perpKeeper.SetParams(ctx, types.DefaultParams()) t.Log("mock vpool keeper") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT(). + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). + Return(vpool, nil) + mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, tc.baseAssetDir, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), ). @@ -1978,12 +2013,12 @@ func TestClosePositionWithBadDebt(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, /*baseAssetDirection=*/ tc.baseAssetDir, /*baseAssetAmount=*/ tc.initialPosition.Size_.Abs(), /*quoteAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ false, - ).Return( /*quoteAssetAmount=*/ tc.newPositionNotional, nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.newPositionNotional, nil) t.Log("set up pair metadata and last cumulative funding rate") SetPairMetadata(perpKeeper, ctx, types.PairMetadata{ diff --git a/x/perp/keeper/grpc_query.go b/x/perp/keeper/grpc_query.go index 0870bc53c..3d96c42a1 100644 --- a/x/perp/keeper/grpc_query.go +++ b/x/perp/keeper/grpc_query.go @@ -74,16 +74,21 @@ func (q queryServer) position(ctx sdk.Context, pair asset.Pair, trader sdk.AccAd return nil, err } - positionNotional, unrealizedPnl, err := q.k.getPositionNotionalAndUnrealizedPnL(ctx, position, types.PnLCalcOption_SPOT_PRICE) + vpool, err := q.k.VpoolKeeper.GetPool(ctx, pair) + if err != nil { + return nil, types.ErrPairNotFound + } + + positionNotional, unrealizedPnl, err := q.k.getPositionNotionalAndUnrealizedPnL(ctx, vpool, position, types.PnLCalcOption_SPOT_PRICE) if err != nil { return nil, err } - marginRatioMark, err := q.k.GetMarginRatio(ctx, position, types.MarginCalculationPriceOption_MAX_PNL) + marginRatioMark, err := q.k.GetMarginRatio(ctx, vpool, position, types.MarginCalculationPriceOption_MAX_PNL) if err != nil { return nil, err } - marginRatioIndex, err := q.k.GetMarginRatio(ctx, position, types.MarginCalculationPriceOption_INDEX) + marginRatioIndex, err := q.k.GetMarginRatio(ctx, vpool, position, types.MarginCalculationPriceOption_INDEX) if err != nil { // The index portion of the query fails silently as not to distrupt all // position queries when oracles aren't posting prices. diff --git a/x/perp/keeper/liquidate.go b/x/perp/keeper/liquidate.go index 88a521bd6..9a3d1dea5 100644 --- a/x/perp/keeper/liquidate.go +++ b/x/perp/keeper/liquidate.go @@ -31,7 +31,11 @@ func (k Keeper) Liquidate( pair asset.Pair, trader sdk.AccAddress, ) (liquidatorFee sdk.Coin, perpEcosystemFundFee sdk.Coin, err error) { - err = k.requireVpool(ctx, pair) + vpool, err := k.VpoolKeeper.GetPool(ctx, pair) + if err != nil { + return sdk.Coin{}, sdk.Coin{}, types.ErrPairNotFound + } + if err != nil { _ = ctx.EventManager().EmitTypedEvent(&types.LiquidationFailedEvent{ // nolint:errcheck Pair: pair, @@ -55,6 +59,7 @@ func (k Keeper) Liquidate( marginRatio, err := k.GetMarginRatio( ctx, + vpool, position, types.MarginCalculationPriceOption_MAX_PNL, ) @@ -68,7 +73,7 @@ func (k Keeper) Liquidate( } if isOverSpreadLimit { marginRatioBasedOnOracle, err := k.GetMarginRatio( - ctx, position, types.MarginCalculationPriceOption_INDEX) + ctx, vpool, position, types.MarginCalculationPriceOption_INDEX) if err != nil { return liquidatorFee, perpEcosystemFundFee, err } @@ -94,7 +99,7 @@ func (k Keeper) Liquidate( } marginRatioBasedOnSpot, err := k.GetMarginRatio( - ctx, position, types.MarginCalculationPriceOption_SPOT) + ctx, vpool, position, types.MarginCalculationPriceOption_SPOT) if err != nil { return } @@ -140,13 +145,19 @@ func (k Keeper) ExecuteFullLiquidation( ) (liquidationResp types.LiquidateResp, err error) { params := k.GetParams(ctx) + vpool, err := k.VpoolKeeper.GetPool(ctx, position.Pair) + if err != nil { + return types.LiquidateResp{}, types.ErrPairNotFound + } + traderAddr, err := sdk.AccAddressFromBech32(position.TraderAddress) if err != nil { return types.LiquidateResp{}, err } - positionResp, err := k.closePositionEntirely( + _, positionResp, err := k.closePositionEntirely( ctx, + vpool, /* currentPosition */ *position, /* quoteAssetAmountLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, @@ -286,6 +297,11 @@ func (k Keeper) ExecutePartialLiquidation( ) (types.LiquidateResp, error) { params := k.GetParams(ctx) + vpool, err := k.VpoolKeeper.GetPool(ctx, currentPosition.Pair) + if err != nil { + return types.LiquidateResp{}, types.ErrPairNotFound + } + traderAddr, err := sdk.AccAddressFromBech32(currentPosition.TraderAddress) if err != nil { return types.LiquidateResp{}, err @@ -299,8 +315,7 @@ func (k Keeper) ExecutePartialLiquidation( } partiallyLiquidatedPositionNotional, err := k.VpoolKeeper.GetBaseAssetPrice( - ctx, - currentPosition.Pair, + vpool, baseAssetDir, /* abs= */ currentPosition.Size_.Mul(params.PartialLiquidationRatio), ) @@ -308,8 +323,9 @@ func (k Keeper) ExecutePartialLiquidation( return types.LiquidateResp{}, err } - positionResp, err := k.decreasePosition( + _, positionResp, err := k.decreasePosition( /* ctx */ ctx, + vpool, /* currentPosition */ *currentPosition, /* quoteAssetAmount */ partiallyLiquidatedPositionNotional, /* baseAmtLimit */ sdk.ZeroDec(), diff --git a/x/perp/keeper/liquidate_unit_test.go b/x/perp/keeper/liquidate_unit_test.go index 2d98a3b1c..6c052f15c 100644 --- a/x/perp/keeper/liquidate_unit_test.go +++ b/x/perp/keeper/liquidate_unit_test.go @@ -109,8 +109,13 @@ func TestLiquidateIntoPartialLiquidation(t *testing.T) { }) t.Log("mock vpool keeper") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). - ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true).Times(2) + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). + Times(2). + Return(vpool, nil) + mocks.mockVpoolKeeper.EXPECT(). + ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true).Times(1) mocks.mockVpoolKeeper.EXPECT(). GetMaintenanceMarginRatio(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). Return(sdk.MustNewDecFromStr("0.0625"), nil) @@ -130,16 +135,14 @@ func TestLiquidateIntoPartialLiquidation(t *testing.T) { Return(tc.newPositionNotional, nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.OneDec(), ). Return(tc.newPositionNotional, nil).Times(3) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, tc.exchangedSize, ). @@ -147,13 +150,13 @@ func TestLiquidateIntoPartialLiquidation(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /* quoteAmt */ tc.exchangedNotional, /* baseLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, ). - Return(tc.exchangedSize, nil) + Return(vpool, tc.exchangedSize, nil) t.Log("mock account keeper") mocks.mockAccountKeeper.EXPECT(). @@ -284,8 +287,12 @@ func TestLiquidateIntoFullLiquidation(t *testing.T) { }) t.Log("mock vpool keeper") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). - ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true).Times(2) + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Times(2). + Return(vpool, nil) + mocks.mockVpoolKeeper.EXPECT(). + ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true) mocks.mockVpoolKeeper.EXPECT(). GetMaintenanceMarginRatio(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). Return(sdk.MustNewDecFromStr("0.0625"), nil) @@ -304,8 +311,7 @@ func TestLiquidateIntoFullLiquidation(t *testing.T) { Return(tc.newPositionNotional, nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, tc.initialPositionSize, ). @@ -313,13 +319,13 @@ func TestLiquidateIntoFullLiquidation(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /* baseAmt */ tc.initialPositionSize, /* quoteLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, ). - Return(tc.newPositionNotional, nil) + Return(vpool, tc.newPositionNotional, nil) t.Log("mock account keeper") mocks.mockAccountKeeper.EXPECT(). @@ -455,8 +461,12 @@ func TestLiquidateIntoFullLiquidationWithBadDebt(t *testing.T) { }) t.Log("mock vpool keeper") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT(). + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Times(2). + Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT(). - ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true).Times(2) + ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(true) mocks.mockVpoolKeeper.EXPECT(). GetMaintenanceMarginRatio(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). Return(sdk.MustNewDecFromStr("0.0625"), nil) @@ -475,8 +485,7 @@ func TestLiquidateIntoFullLiquidationWithBadDebt(t *testing.T) { Return(tc.newPositionNotional, nil) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, tc.initialPositionSize, ). @@ -484,13 +493,13 @@ func TestLiquidateIntoFullLiquidationWithBadDebt(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /* baseAmt */ tc.initialPositionSize, /* quoteLimit */ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, ). - Return(tc.newPositionNotional, nil) + Return(vpool, tc.newPositionNotional, nil) t.Log("mock account keeper") mocks.mockAccountKeeper.EXPECT(). @@ -934,11 +943,12 @@ func TestKeeper_ExecuteFullLiquidation(t *testing.T) { }) t.Log("mock vpool") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).AnyTimes().Return(true) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, baseAssetDirection, /*baseAssetAmount=*/ tc.initialPositionSize.Abs(), ). @@ -946,12 +956,12 @@ func TestKeeper_ExecuteFullLiquidation(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapBaseForQuote( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, baseAssetDirection, /*baseAssetAmount=*/ tc.initialPositionSize.Abs(), /*quoteAssetAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, - ).Return( /*quoteAssetAmount=*/ tc.baseAssetPriceInQuote, nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.baseAssetPriceInQuote, nil) mocks.mockVpoolKeeper.EXPECT(). GetMarkPrice(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). Return(sdk.OneDec(), nil) @@ -1225,11 +1235,14 @@ func TestKeeper_ExecutePartialLiquidation(t *testing.T) { }) t.Log("mock vpool") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT(). + GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)). + Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)).AnyTimes().Return(true) mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, baseAssetDirection, /*baseAssetAmount=*/ tc.initialPositionSize.Mul(tc.partialLiquidationRatio), ). @@ -1237,8 +1250,7 @@ func TestKeeper_ExecutePartialLiquidation(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, baseAssetDirection, /*baseAssetAmount=*/ tc.initialPositionSize.Abs(), ). @@ -1248,22 +1260,22 @@ func TestKeeper_ExecutePartialLiquidation(t *testing.T) { mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, /*baseAssetAmount=*/ tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), /*quoteAssetAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, - ).Return( /*quoteAssetAmount=*/ tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), nil) } else { mocks.mockVpoolKeeper.EXPECT(). SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, /*baseAssetAmount=*/ tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), /*quoteAssetAssetLimit=*/ sdk.ZeroDec(), /* skipFluctuationLimitCheck */ true, - ).Return( /*quoteAssetAmount=*/ tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), nil) + ).Return(vpool /*quoteAssetAmount=*/, tc.baseAssetPriceInQuote.Mul(tc.partialLiquidationRatio), nil) } mocks.mockVpoolKeeper.EXPECT(). diff --git a/x/perp/keeper/margin.go b/x/perp/keeper/margin.go index 3ed770f3b..af195b361 100644 --- a/x/perp/keeper/margin.go +++ b/x/perp/keeper/margin.go @@ -20,9 +20,9 @@ to it. Adding margin increases the margin ratio of the corresponding position. func (k Keeper) AddMargin( ctx sdk.Context, pair asset.Pair, traderAddr sdk.AccAddress, margin sdk.Coin, ) (res *types.MsgAddMarginResponse, err error) { - // validate vpool exists - if err = k.requireVpool(ctx, pair); err != nil { - return nil, err + vpool, err := k.VpoolKeeper.GetPool(ctx, pair) + if err != nil { + return nil, types.ErrPairNotFound } // ------------- AddMargin ------------- @@ -54,7 +54,7 @@ func (k Keeper) AddMargin( position.BlockNumber = ctx.BlockHeight() k.Positions.Insert(ctx, collections.Join(position.Pair, traderAddr), position) - positionNotional, unrealizedPnl, err := k.getPositionNotionalAndUnrealizedPnL(ctx, position, types.PnLCalcOption_SPOT_PRICE) + positionNotional, unrealizedPnl, err := k.getPositionNotionalAndUnrealizedPnL(ctx, vpool, position, types.PnLCalcOption_SPOT_PRICE) if err != nil { return nil, err } @@ -114,9 +114,9 @@ ret: func (k Keeper) RemoveMargin( ctx sdk.Context, pair asset.Pair, traderAddr sdk.AccAddress, margin sdk.Coin, ) (marginOut sdk.Coin, fundingPayment sdk.Dec, position types.Position, err error) { - // validate vpool exists - if err = k.requireVpool(ctx, pair); err != nil { - return sdk.Coin{}, sdk.Dec{}, types.Position{}, err + vpool, err := k.VpoolKeeper.GetPool(ctx, pair) + if err != nil { + return sdk.Coin{}, sdk.Dec{}, types.Position{}, types.ErrPairNotFound } // ------------- RemoveMargin ------------- @@ -137,7 +137,7 @@ func (k Keeper) RemoveMargin( position.Margin = remainingMargin.Margin position.LatestCumulativePremiumFraction = remainingMargin.LatestCumulativePremiumFraction - freeCollateral, err := k.calcFreeCollateral(ctx, position) + freeCollateral, err := k.calcFreeCollateral(ctx, vpool, position) if err != nil { return sdk.Coin{}, sdk.Dec{}, types.Position{}, err } else if !freeCollateral.IsPositive() { @@ -146,7 +146,7 @@ func (k Keeper) RemoveMargin( k.Positions.Insert(ctx, collections.Join(position.Pair, traderAddr), position) - positionNotional, unrealizedPnl, err := k.getPositionNotionalAndUnrealizedPnL(ctx, position, types.PnLCalcOption_SPOT_PRICE) + positionNotional, unrealizedPnl, err := k.getPositionNotionalAndUnrealizedPnL(ctx, vpool, position, types.PnLCalcOption_SPOT_PRICE) if err != nil { return sdk.Coin{}, sdk.Dec{}, types.Position{}, err } @@ -187,7 +187,7 @@ func (k Keeper) RemoveMargin( // GetMarginRatio calculates the MarginRatio from a Position func (k Keeper) GetMarginRatio( - ctx sdk.Context, position types.Position, priceOption types.MarginCalculationPriceOption, + ctx sdk.Context, vpool vpooltypes.Vpool, position types.Position, priceOption types.MarginCalculationPriceOption, ) (marginRatio sdk.Dec, err error) { if position.Size_.IsZero() { return sdk.Dec{}, types.ErrPositionZero @@ -202,18 +202,21 @@ func (k Keeper) GetMarginRatio( case types.MarginCalculationPriceOption_MAX_PNL: positionNotional, unrealizedPnL, err = k.GetPreferencePositionNotionalAndUnrealizedPnL( ctx, + vpool, position, types.PnLPreferenceOption_MAX, ) case types.MarginCalculationPriceOption_INDEX: positionNotional, unrealizedPnL, err = k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, position, types.PnLCalcOption_ORACLE, ) case types.MarginCalculationPriceOption_SPOT: positionNotional, unrealizedPnL, err = k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, position, types.PnLCalcOption_SPOT_PRICE, ) @@ -293,6 +296,7 @@ Returns: */ func (k Keeper) getPositionNotionalAndUnrealizedPnL( ctx sdk.Context, + vpool vpooltypes.Vpool, currentPosition types.Position, pnlCalcOption types.PnLCalcOption, ) (positionNotional sdk.Dec, unrealizedPnL sdk.Dec, err error) { @@ -325,8 +329,7 @@ func (k Keeper) getPositionNotionalAndUnrealizedPnL( } case types.PnLCalcOption_SPOT_PRICE: positionNotional, err = k.VpoolKeeper.GetBaseAssetPrice( - ctx, - currentPosition.Pair, + vpool, baseAssetDirection, positionSizeAbs, ) @@ -388,11 +391,13 @@ Returns: */ func (k Keeper) GetPreferencePositionNotionalAndUnrealizedPnL( ctx sdk.Context, + vpool vpooltypes.Vpool, position types.Position, pnLPreferenceOption types.PnLPreferenceOption, ) (positionNotional sdk.Dec, unrealizedPnl sdk.Dec, err error) { spotPositionNotional, spotPricePnl, err := k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, position, types.PnLCalcOption_SPOT_PRICE, ) @@ -409,6 +414,7 @@ func (k Keeper) GetPreferencePositionNotionalAndUnrealizedPnL( twapPositionNotional, twapPnl, err := k.getPositionNotionalAndUnrealizedPnL( ctx, + vpool, position, types.PnLCalcOption_TWAP, ) diff --git a/x/perp/keeper/margin_unit_test.go b/x/perp/keeper/margin_unit_test.go index 91e051057..c845f6d8f 100644 --- a/x/perp/keeper/margin_unit_test.go +++ b/x/perp/keeper/margin_unit_test.go @@ -85,8 +85,9 @@ func TestGetMarginRatio_Errors(t *testing.T) { Size_: sdk.ZeroDec(), } + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} _, err := k.GetMarginRatio( - ctx, pos, types.MarginCalculationPriceOption_MAX_PNL) + ctx, vpool, pos, types.MarginCalculationPriceOption_MAX_PNL) assert.EqualError(t, err, types.ErrPositionZero.Error()) }, }, @@ -141,10 +142,10 @@ func TestGetMarginRatio(t *testing.T) { perpKeeper, mocks, ctx := getKeeper(t) t.Log("Mock vpool spot price") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, tc.position.Size_.Abs(), ). @@ -166,7 +167,8 @@ func TestGetMarginRatio(t *testing.T) { }) marginRatio, err := perpKeeper.GetMarginRatio( - ctx, tc.position, types.MarginCalculationPriceOption_MAX_PNL) + ctx, vpool, tc.position, types.MarginCalculationPriceOption_MAX_PNL, + ) require.NoError(t, err) require.Equal(t, tc.expectedMarginRatio, marginRatio) @@ -188,7 +190,7 @@ func TestRemoveMargin(t *testing.T) { traderAddr := testutilevents.AccAddress() pair := asset.NewPair("osmo", "nusd") - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpooltypes.Vpool{Pair: pair}, nil) t.Log("Set vpool defined by pair on PerpKeeper") SetPairMetadata(perpKeeper, ctx, types.PairMetadata{ @@ -223,13 +225,13 @@ func TestRemoveMargin(t *testing.T) { marginToWithdraw := sdk.NewInt64Coin(pair.QuoteDenom(), 100) t.Log("mock vpool keeper") - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).AnyTimes().Return(true) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT().GetMaintenanceMarginRatio(ctx, pair). Return(sdk.MustNewDecFromStr("0.0625"), nil) mocks.mockVpoolKeeper.EXPECT().GetMarkPrice(ctx, pair).Return(sdk.OneDec(), nil) mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice( - ctx, - pair, + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1_000), ).Return(sdk.NewDec(1000), nil).Times(2) @@ -291,14 +293,14 @@ func TestRemoveMargin(t *testing.T) { marginToWithdraw := sdk.NewInt64Coin(pair.QuoteDenom(), 100) t.Log("mock vpool keeper") - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpool, nil) mocks.mockVpoolKeeper.EXPECT().GetMaintenanceMarginRatio(ctx, pair). Return(sdk.MustNewDecFromStr("0.0625"), nil) - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) mocks.mockVpoolKeeper.EXPECT().GetMarkPrice(ctx, pair).Return(sdk.OneDec(), nil) mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice( - ctx, pair, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1_000)). + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1_000)). Return(sdk.NewDec(1000), nil).Times(2) mocks.mockVpoolKeeper.EXPECT().GetBaseAssetTWAP( ctx, pair, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1_000), @@ -387,7 +389,8 @@ func TestRemoveMargin(t *testing.T) { marginToWithdraw := sdk.NewInt64Coin(pair.QuoteDenom(), 100) t.Log("mock vpool keeper") - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) + vpool := vpooltypes.Vpool{Pair: pair} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpool, nil) t.Log("set pair metadata") SetPairMetadata(perpKeeper, ctx, types.PairMetadata{ @@ -447,7 +450,7 @@ func TestAddMargin(t *testing.T) { Pair: pair, LatestCumulativePremiumFraction: sdk.ZeroDec(), }) - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpooltypes.Vpool{Pair: pair}, nil) t.Log("set a position") SetPosition(perpKeeper, ctx, types.Position{ @@ -480,8 +483,9 @@ func TestAddMargin(t *testing.T) { traderAddr := testutilevents.AccAddress() margin := sdk.NewInt64Coin("unusd", 100) - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) - mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice(ctx, pair, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1000)).Return(sdk.NewDec(1000), nil) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpool, nil) + mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice(vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1000)).Return(sdk.NewDec(1000), nil) mocks.mockVpoolKeeper.EXPECT().GetMarkPrice(ctx, pair).Return(sdk.OneDec(), nil) t.Log("set pair metadata") @@ -551,8 +555,9 @@ func TestAddMargin(t *testing.T) { traderAddr := testutilevents.AccAddress() margin := sdk.NewInt64Coin("unusd", 100) - mocks.mockVpoolKeeper.EXPECT().ExistsPool(ctx, pair).Return(true) - mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice(ctx, pair, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1000)).Return(sdk.NewDec(1000), nil) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} + mocks.mockVpoolKeeper.EXPECT().GetPool(ctx, pair).Return(vpool, nil) + mocks.mockVpoolKeeper.EXPECT().GetBaseAssetPrice(vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(1000)).Return(sdk.NewDec(1000), nil) mocks.mockVpoolKeeper.EXPECT().GetMarkPrice(ctx, pair).Return(sdk.OneDec(), nil) t.Log("set pair metadata") @@ -641,10 +646,10 @@ func TestGetPositionNotionalAndUnrealizedPnl(t *testing.T) { Margin: sdk.NewDec(1), }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -664,10 +669,10 @@ func TestGetPositionNotionalAndUnrealizedPnl(t *testing.T) { Margin: sdk.NewDec(1), }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -777,10 +782,10 @@ func TestGetPositionNotionalAndUnrealizedPnl(t *testing.T) { Margin: sdk.NewDec(1), }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, sdk.NewDec(10), ). @@ -800,10 +805,10 @@ func TestGetPositionNotionalAndUnrealizedPnl(t *testing.T) { Margin: sdk.NewDec(1), }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_REMOVE_FROM_POOL, sdk.NewDec(10), ). @@ -912,9 +917,11 @@ func TestGetPositionNotionalAndUnrealizedPnl(t *testing.T) { tc.setMocks(ctx, mocks) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} positionalNotional, unrealizedPnl, err := perpKeeper. getPositionNotionalAndUnrealizedPnL( ctx, + vpool, tc.initialPosition, tc.pnlCalcOption, ) @@ -949,10 +956,10 @@ func TestGetPreferencePositionNotionalAndUnrealizedPnL(t *testing.T) { }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { t.Log("Mock vpool spot price") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -983,10 +990,10 @@ func TestGetPreferencePositionNotionalAndUnrealizedPnL(t *testing.T) { }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { t.Log("Mock vpool spot price") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -1017,10 +1024,10 @@ func TestGetPreferencePositionNotionalAndUnrealizedPnL(t *testing.T) { }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { t.Log("Mock vpool spot price") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -1051,10 +1058,10 @@ func TestGetPreferencePositionNotionalAndUnrealizedPnL(t *testing.T) { }, setMocks: func(ctx sdk.Context, mocks mockedDependencies) { t.Log("Mock vpool spot price") + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} mocks.mockVpoolKeeper.EXPECT(). GetBaseAssetPrice( - ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, vpooltypes.Direction_ADD_TO_POOL, sdk.NewDec(10), ). @@ -1083,9 +1090,11 @@ func TestGetPreferencePositionNotionalAndUnrealizedPnL(t *testing.T) { tc.setMocks(ctx, mocks) + vpool := vpooltypes.Vpool{Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD)} positionalNotional, unrealizedPnl, err := perpKeeper. GetPreferencePositionNotionalAndUnrealizedPnL( ctx, + vpool, tc.initPosition, tc.pnlPreferenceOption, ) diff --git a/x/perp/keeper/open_gas_test.go b/x/perp/keeper/open_gas_test.go new file mode 100644 index 000000000..6f090014e --- /dev/null +++ b/x/perp/keeper/open_gas_test.go @@ -0,0 +1,43 @@ +package keeper_test + +import ( + "testing" + "time" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/NibiruChain/nibiru/x/common/asset" + "github.com/NibiruChain/nibiru/x/common/denoms" + testutilevents "github.com/NibiruChain/nibiru/x/common/testutil" + . "github.com/NibiruChain/nibiru/x/oracle/integration_test/action" + . "github.com/NibiruChain/nibiru/x/perp/integration/action" + perptypes "github.com/NibiruChain/nibiru/x/perp/types" + . "github.com/NibiruChain/nibiru/x/testutil" + . "github.com/NibiruChain/nibiru/x/testutil/action" + "github.com/NibiruChain/nibiru/x/testutil/assertion" +) + +func TestOpenGasConsumed(t *testing.T) { + ts := NewTestSuite(t) + + alice := testutilevents.AccAddress() + pairBtcUsdc := asset.Registry.Pair(denoms.BTC, denoms.USDC) + + testCases := TestCases{ + TC("open position gas consumed"). + Given( + createInitVPool(), + SetBlockTime(time.Now()), + SetBlockNumber(1), + SetPairPrice(pairBtcUsdc, sdk.NewDec(10000)), + FundAccount(alice, sdk.NewCoins(sdk.NewCoin(denoms.USDC, sdk.NewInt(1020)))), + ). + When( + OpenPosition(alice, pairBtcUsdc, perptypes.Side_BUY, sdk.NewInt(1000), sdk.NewDec(10), sdk.ZeroDec()), + ).Then( + assertion.GasConsumedShouldBe(148190), + ), + } + + ts.WithTestCases(testCases...).Run() +} diff --git a/x/perp/types/expected_keepers.go b/x/perp/types/expected_keepers.go index 4a7c51dd8..067f9296f 100644 --- a/x/perp/types/expected_keepers.go +++ b/x/perp/types/expected_keepers.go @@ -52,21 +52,21 @@ type OracleKeeper interface { type VpoolKeeper interface { SwapBaseForQuote( ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, dir vpooltypes.Direction, baseAssetAmount sdk.Dec, quoteAmountLimit sdk.Dec, skipFluctuationLimitCheck bool, - ) (sdk.Dec, error) + ) (vpooltypes.Vpool, sdk.Dec, error) SwapQuoteForBase( ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, dir vpooltypes.Direction, quoteAssetAmount sdk.Dec, baseAmountLimit sdk.Dec, skipFluctuationLimitCheck bool, - ) (sdk.Dec, error) + ) (vpooltypes.Vpool, sdk.Dec, error) GetBaseAssetTWAP( ctx sdk.Context, @@ -77,8 +77,7 @@ type VpoolKeeper interface { ) (quoteAssetAmount sdk.Dec, err error) GetBaseAssetPrice( - ctx sdk.Context, - pair asset.Pair, + vpool vpooltypes.Vpool, direction vpooltypes.Direction, baseAssetAmount sdk.Dec, ) (quoteAssetAmount sdk.Dec, err error) @@ -102,10 +101,10 @@ type VpoolKeeper interface { ) (quoteAssetAmount sdk.Dec, err error) GetAllPools(ctx sdk.Context) []vpooltypes.Vpool + GetPool(ctx sdk.Context, pair asset.Pair) (vpooltypes.Vpool, error) IsOverSpreadLimit(ctx sdk.Context, pair asset.Pair) (bool, error) GetMaintenanceMarginRatio(ctx sdk.Context, pair asset.Pair) (sdk.Dec, error) - GetMaxLeverage(ctx sdk.Context, pair asset.Pair) (sdk.Dec, error) ExistsPool(ctx sdk.Context, pair asset.Pair) bool GetSettlementPrice(ctx sdk.Context, pair asset.Pair) (sdk.Dec, error) GetLastSnapshot(ctx sdk.Context, pool vpooltypes.Vpool) (vpooltypes.ReserveSnapshot, error) diff --git a/x/testutil/assertion/gas.go b/x/testutil/assertion/gas.go new file mode 100644 index 000000000..fc0e15de0 --- /dev/null +++ b/x/testutil/assertion/gas.go @@ -0,0 +1,27 @@ +package assertion + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/NibiruChain/nibiru/app" + "github.com/NibiruChain/nibiru/x/testutil" +) + +type gasConsumedShouldBe struct { + gasConsumed uint64 +} + +func (g gasConsumedShouldBe) Do(_ *app.NibiruApp, ctx sdk.Context) (sdk.Context, error) { + gasUsed := ctx.GasMeter().GasConsumed() + if g.gasConsumed != gasUsed { + return ctx, fmt.Errorf("gas consumed should be %d, but got %d", g.gasConsumed, gasUsed) + } + + return ctx, nil +} + +func GasConsumedShouldBe(gasConsumed uint64) testutil.Action { + return &gasConsumedShouldBe{gasConsumed: gasConsumed} +} diff --git a/x/vpool/abci_test.go b/x/vpool/abci_test.go index bb69e6737..890183cb4 100644 --- a/x/vpool/abci_test.go +++ b/x/vpool/abci_test.go @@ -53,9 +53,11 @@ func TestSnapshotUpdates(t *testing.T) { assert.EqualValues(t, expectedSnapshot, snapshot) t.Log("affect mark price") - baseAmtAbs, err := vpoolKeeper.SwapQuoteForBase( + vpool, err := vpoolKeeper.GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)) + require.NoError(t, err) + _, baseAmtAbs, err := vpoolKeeper.SwapQuoteForBase( ctx, - asset.Registry.Pair(denoms.BTC, denoms.NUSD), + vpool, types.Direction_ADD_TO_POOL, sdk.NewDec(250), // ← dyAmm sdk.ZeroDec(), diff --git a/x/vpool/keeper/keeper.go b/x/vpool/keeper/keeper.go index eb445416c..40db877e2 100644 --- a/x/vpool/keeper/keeper.go +++ b/x/vpool/keeper/keeper.go @@ -65,48 +65,43 @@ ret: */ func (k Keeper) SwapBaseForQuote( ctx sdk.Context, - pair asset.Pair, + pool types.Vpool, dir types.Direction, baseAmt sdk.Dec, quoteLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (quoteAmtAbs sdk.Dec, err error) { +) (updatedVpool types.Vpool, quoteAmtAbs sdk.Dec, err error) { if baseAmt.IsZero() { - return sdk.ZeroDec(), nil + return pool, sdk.ZeroDec(), nil } - if _, err = k.oracleKeeper.GetExchangeRate(ctx, pair); err != nil { - return sdk.Dec{}, types.ErrNoValidPrice.Wrapf("%s", pair) - } - - pool, err := k.Pools.Get(ctx, pair) - if err != nil { - return sdk.Dec{}, types.ErrPairNotSupported + if _, err = k.oracleKeeper.GetExchangeRate(ctx, pool.Pair); err != nil { + return pool, sdk.Dec{}, types.ErrNoValidPrice.Wrapf("%s", pool.Pair) } baseAmtAbs := baseAmt.Abs() quoteAmtAbs, err = pool.GetQuoteAmountByBaseAmount(baseAmtAbs.MulInt64(dir.ToMultiplier())) if err != nil { - return sdk.Dec{}, err + return pool, sdk.Dec{}, err } if err := pool.HasEnoughReservesForTrade(quoteAmtAbs, baseAmtAbs); err != nil { - return sdk.Dec{}, err + return pool, sdk.Dec{}, err } if err := checkIfLimitIsViolated(quoteLimit, quoteAmtAbs, dir); err != nil { - return sdk.Dec{}, err + return pool, sdk.Dec{}, err } quoteDelta := quoteAmtAbs.Neg().MulInt64(dir.ToMultiplier()) baseAmt = baseAmtAbs.MulInt64(dir.ToMultiplier()) - pool, err = k.executeSwap(ctx, pool, quoteDelta, baseAmt, skipFluctuationLimitCheck) + updatedVpool, err = k.executeSwap(ctx, pool, quoteDelta, baseAmt, skipFluctuationLimitCheck) if err != nil { - return sdk.Dec{}, fmt.Errorf("error updating reserve: %w", err) + return pool, sdk.Dec{}, fmt.Errorf("error updating reserve: %w", err) } - return quoteAmtAbs, err + return updatedVpool, quoteAmtAbs, err } func (k Keeper) executeSwap( @@ -162,50 +157,45 @@ ret: */ func (k Keeper) SwapQuoteForBase( ctx sdk.Context, - pair asset.Pair, + vpool types.Vpool, dir types.Direction, quoteAmt sdk.Dec, baseLimit sdk.Dec, skipFluctuationLimitCheck bool, -) (baseAmtAbs sdk.Dec, err error) { +) (updatedVpool types.Vpool, baseAmtAbs sdk.Dec, err error) { if quoteAmt.IsZero() { - return sdk.ZeroDec(), nil + return types.Vpool{}, sdk.ZeroDec(), nil } - if _, err = k.oracleKeeper.GetExchangeRate(ctx, pair); err != nil { - return sdk.Dec{}, types.ErrNoValidPrice.Wrapf("%s", pair) - } - - pool, err := k.Pools.Get(ctx, pair) - if err != nil { - return sdk.Dec{}, types.ErrPairNotSupported + if _, err = k.oracleKeeper.GetExchangeRate(ctx, vpool.Pair); err != nil { + return types.Vpool{}, sdk.Dec{}, types.ErrNoValidPrice.Wrapf("%s", vpool.Pair) } // check trade limit ratio on quote in either direction quoteAmtAbs := quoteAmt.Abs() - baseAmtAbs, err = pool.GetBaseAmountByQuoteAmount( + baseAmtAbs, err = vpool.GetBaseAmountByQuoteAmount( quoteAmtAbs.MulInt64(dir.ToMultiplier())) if err != nil { - return sdk.Dec{}, err + return types.Vpool{}, sdk.Dec{}, err } - if err := pool.HasEnoughReservesForTrade(quoteAmtAbs, baseAmtAbs); err != nil { - return sdk.Dec{}, err + if err := vpool.HasEnoughReservesForTrade(quoteAmtAbs, baseAmtAbs); err != nil { + return types.Vpool{}, sdk.Dec{}, err } if err := checkIfLimitIsViolated(baseLimit, baseAmtAbs, dir); err != nil { - return sdk.Dec{}, err + return types.Vpool{}, sdk.Dec{}, err } quoteAmt = quoteAmtAbs.MulInt64(dir.ToMultiplier()) baseDelta := baseAmtAbs.Neg().MulInt64(dir.ToMultiplier()) - pool, err = k.executeSwap(ctx, pool, quoteAmt, baseDelta, skipFluctuationLimitCheck) + updatedVpool, err = k.executeSwap(ctx, vpool, quoteAmt, baseDelta, skipFluctuationLimitCheck) if err != nil { - return sdk.Dec{}, fmt.Errorf("error updating reserve: %w", err) + return types.Vpool{}, sdk.Dec{}, fmt.Errorf("error updating reserve: %w", err) } - return baseAmtAbs, err + return updatedVpool, baseAmtAbs, err } // checkIfLimitIsViolated checks if the limit is violated by the amount. @@ -322,26 +312,6 @@ func (k Keeper) GetMaintenanceMarginRatio(ctx sdk.Context, pair asset.Pair) (sdk return pool.Config.MaintenanceMarginRatio, nil } -/* -GetMaxLeverage returns the maximum leverage required to open a position in the pool. - -args: - - ctx: the cosmos-sdk context - - pair: the asset pair - -ret: - - sdk.Dec: The maintenance margin ratio for the pool - - error -*/ -func (k Keeper) GetMaxLeverage(ctx sdk.Context, pair asset.Pair) (sdk.Dec, error) { - pool, err := k.Pools.Get(ctx, pair) - if err != nil { - return sdk.Dec{}, err - } - - return pool.Config.MaxLeverage, nil -} - /* GetAllPools returns an array of all the pools @@ -354,3 +324,7 @@ ret: func (k Keeper) GetAllPools(ctx sdk.Context) []types.Vpool { return k.Pools.Iterate(ctx, collections.Range[asset.Pair]{}).Values() } + +func (k Keeper) GetPool(ctx sdk.Context, pair asset.Pair) (types.Vpool, error) { + return k.Pools.Get(ctx, pair) +} diff --git a/x/vpool/keeper/keeper_test.go b/x/vpool/keeper/keeper_test.go index a7e0391a6..096327e06 100644 --- a/x/vpool/keeper/keeper_test.go +++ b/x/vpool/keeper/keeper_test.go @@ -68,16 +68,6 @@ func TestSwapQuoteForBase(t *testing.T) { expectedBaseReserve: sdk.MustNewDecFromStr("5050505.050505050505050505"), expectedBaseAmount: sdk.MustNewDecFromStr("50505.050505050505050505"), }, - { - name: "pair not supported", - pair: "abc:xyz", - direction: types.Direction_ADD_TO_POOL, - quoteAmount: sdk.NewDec(10), - baseLimit: sdk.NewDec(10), - skipFluctuationLimitCheck: false, - - expectedErr: types.ErrPairNotSupported, - }, { name: "base amount less than base limit in Long", pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), @@ -185,10 +175,12 @@ func TestSwapQuoteForBase(t *testing.T) { MaxLeverage: sdk.MustNewDecFromStr("15"), }, )) + vpool, err := vpoolKeeper.GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)) + require.NoError(t, err) - baseAmt, err := vpoolKeeper.SwapQuoteForBase( + _, baseAmt, err := vpoolKeeper.SwapQuoteForBase( ctx, - tc.pair, + vpool, tc.direction, tc.quoteAmount, tc.baseLimit, @@ -261,16 +253,6 @@ func TestSwapBaseForQuote(t *testing.T) { expectedBaseReserve: sdk.NewDec(4_900_000), expectedQuoteAssetAmount: sdk.MustNewDecFromStr("204081.632653061224489796"), }, - { - name: "pair not supported", - pair: "abc:xyz", - direction: types.Direction_ADD_TO_POOL, - baseAmt: sdk.NewDec(10), - quoteLimit: sdk.NewDec(10), - skipFluctuationLimitCheck: false, - - expectedErr: types.ErrPairNotSupported, - }, { name: "quote amount less than quote limit in Long", pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), @@ -363,7 +345,8 @@ func TestSwapBaseForQuote(t *testing.T) { pfKeeper := mock.NewMockOracleKeeper(gomock.NewController(t)) vpoolKeeper, ctx := VpoolKeeper(t, pfKeeper) - pfKeeper.EXPECT().GetExchangeRate(gomock.Any(), gomock.Any()).Return(sdk.NewDec(1), nil).AnyTimes() + pfKeeper.EXPECT(). + GetExchangeRate(gomock.Any(), gomock.Any()).Return(sdk.NewDec(1), nil).AnyTimes() assert.NoError(t, vpoolKeeper.CreatePool( ctx, @@ -379,9 +362,11 @@ func TestSwapBaseForQuote(t *testing.T) { }, )) - quoteAssetAmount, err := vpoolKeeper.SwapBaseForQuote( + vpool, err := vpoolKeeper.GetPool(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)) + require.NoError(t, err) + _, quoteAssetAmount, err := vpoolKeeper.SwapBaseForQuote( ctx, - tc.pair, + vpool, tc.direction, tc.baseAmt, tc.quoteLimit, @@ -702,43 +687,3 @@ func TestGetMaintenanceMarginRatio(t *testing.T) { }) } } - -func TestGetMaxLeverage(t *testing.T) { - tests := []struct { - name string - pool types.Vpool - - expectedMaxLeverage sdk.Dec - }{ - { - name: "zero fluctuation limit ratio", - pool: types.Vpool{ - Pair: asset.Registry.Pair(denoms.BTC, denoms.NUSD), - QuoteAssetReserve: sdk.OneDec(), - BaseAssetReserve: sdk.OneDec(), - Config: types.VpoolConfig{ - FluctuationLimitRatio: sdk.ZeroDec(), - MaintenanceMarginRatio: sdk.MustNewDecFromStr("0.42"), - MaxLeverage: sdk.MustNewDecFromStr("15"), - MaxOracleSpreadRatio: sdk.OneDec(), - TradeLimitRatio: sdk.OneDec(), - }, - }, - expectedMaxLeverage: sdk.MustNewDecFromStr("15"), - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - vpoolKeeper, ctx := VpoolKeeper(t, - mock.NewMockOracleKeeper(gomock.NewController(t)), - ) - vpoolKeeper.Pools.Insert(ctx, tc.pool.Pair, tc.pool) - - maxLeverage, err := vpoolKeeper.GetMaxLeverage(ctx, asset.Registry.Pair(denoms.BTC, denoms.NUSD)) - assert.EqualValues(t, tc.expectedMaxLeverage, maxLeverage) - assert.NoError(t, err) - }) - } -} diff --git a/x/vpool/keeper/prices.go b/x/vpool/keeper/prices.go index 07716fd5a..acf9c50f2 100644 --- a/x/vpool/keeper/prices.go +++ b/x/vpool/keeper/prices.go @@ -55,17 +55,11 @@ ret: - err: error */ func (k Keeper) GetBaseAssetPrice( - ctx sdk.Context, - pair asset.Pair, + vpool types.Vpool, dir types.Direction, baseAssetAmount sdk.Dec, ) (quoteAmount sdk.Dec, err error) { - pool, err := k.Pools.Get(ctx, pair) - if err != nil { - return sdk.ZeroDec(), err - } - - return pool.GetQuoteAmountByBaseAmount(baseAssetAmount.MulInt64(dir.ToMultiplier())) + return vpool.GetQuoteAmountByBaseAmount(baseAssetAmount.MulInt64(dir.ToMultiplier())) } /* diff --git a/x/vpool/keeper/prices_test.go b/x/vpool/keeper/prices_test.go index d5d6b236e..8db271d27 100644 --- a/x/vpool/keeper/prices_test.go +++ b/x/vpool/keeper/prices_test.go @@ -137,7 +137,10 @@ func TestGetBaseAssetPrice(t *testing.T) { }, )) - quoteAmount, err := vpoolKeeper.GetBaseAssetPrice(ctx, tc.pair, tc.direction, tc.baseAmount) + vpool, err := vpoolKeeper.GetPool(ctx, tc.pair) + require.NoError(t, err) + + quoteAmount, err := vpoolKeeper.GetBaseAssetPrice(vpool, tc.direction, tc.baseAmount) if tc.expectedErr != nil { require.ErrorIs(t, err, tc.expectedErr, "expected error: %w, got: %w", tc.expectedErr, err) diff --git a/x/vpool/keeper/query_server.go b/x/vpool/keeper/query_server.go index c9bba8689..04246c0bb 100644 --- a/x/vpool/keeper/query_server.go +++ b/x/vpool/keeper/query_server.go @@ -82,9 +82,12 @@ func (q queryServer) BaseAssetPrice( ctx := sdk.UnwrapSDKContext(goCtx) + vpool, err := q.k.GetPool(ctx, req.Pair) + if err != nil { + return nil, types.ErrPairNotSupported + } priceInQuoteDenom, err := q.k.GetBaseAssetPrice( - ctx, - req.Pair, + vpool, req.Direction, req.BaseAssetAmount, )