Skip to content

Commit

Permalink
fix: duplicate entries on GET /accounts when filtering on balances (#481
Browse files Browse the repository at this point in the history
)

* fix: duplicate entries

* fix: duplicate entries in GET /accounts when filtering by balances

Also add 'balanceAsset' query param on GET /accounts to allow specifiying the asset to filter.
  • Loading branch information
gfyrag authored Feb 29, 2024
1 parent 187ea53 commit ad9eba2
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 14 deletions.
1 change: 1 addition & 0 deletions pkg/api/controllers/account_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func (ctl *AccountController) GetAccounts(c *gin.Context) {
WithAddressFilter(c.Query("address")).
WithBalanceFilter(balance).
WithBalanceOperatorFilter(balanceOperator).
WithBalanceAssetFilter(c.Query("balanceAsset")).
WithMetadataFilter(c.QueryMap("metadata")).
WithPageSize(pageSize)
}
Expand Down
42 changes: 35 additions & 7 deletions pkg/api/controllers/account_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ func TestGetAccounts(t *testing.T) {
}, false)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)

rsp = internal.PostTransaction(t, api, controllers.PostTransaction{
Postings: core.Postings{
{
Source: "world",
Destination: "fred",
Amount: core.NewMonetaryInt(10),
Asset: "EUR",
},
},
}, false)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)

meta := core.Metadata{
"roles": "admin",
"accountId": float64(3),
Expand All @@ -67,16 +79,17 @@ func TestGetAccounts(t *testing.T) {

rsp = internal.CountAccounts(api, url.Values{})
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)
require.Equal(t, "3", rsp.Header().Get("Count"))
require.Equal(t, "4", rsp.Header().Get("Count"))

t.Run("all", func(t *testing.T) {
rsp = internal.GetAccounts(api, url.Values{})
assert.Equal(t, http.StatusOK, rsp.Result().StatusCode)
cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body)
// 3 accounts: world, bob, alice
assert.Len(t, cursor.Data, 3)
assert.Len(t, cursor.Data, 4)
assert.Equal(t, []core.Account{
{Address: "world", Metadata: core.Metadata{}},
{Address: "fred", Metadata: core.Metadata{}},
{Address: "bob", Metadata: meta},
{Address: "alice", Metadata: core.Metadata{}},
}, cursor.Data)
Expand Down Expand Up @@ -261,6 +274,18 @@ func TestGetAccounts(t *testing.T) {
assert.Equal(t, "alice", string(cursor.Data[0].Address))
})

t.Run("filter by balance >= 0 and asset specified", func(t *testing.T) {
rsp = internal.GetAccounts(api, url.Values{
"balanceAsset": []string{"EUR"},
"balance": []string{"0"},
controllers.QueryKeyBalanceOperator: []string{"gte"},
})
assert.Equal(t, http.StatusOK, rsp.Result().StatusCode)
cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body)
assert.Len(t, cursor.Data, 1)
assert.Equal(t, "fred", string(cursor.Data[0].Address))
})

t.Run("filter by balance > 120", func(t *testing.T) {
rsp = internal.GetAccounts(api, url.Values{
"balance": []string{"120"},
Expand Down Expand Up @@ -290,8 +315,9 @@ func TestGetAccounts(t *testing.T) {
})
assert.Equal(t, http.StatusOK, rsp.Result().StatusCode)
cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body)
assert.Len(t, cursor.Data, 1)
assert.Len(t, cursor.Data, 2)
assert.Equal(t, "world", string(cursor.Data[0].Address))
assert.Equal(t, "fred", string(cursor.Data[1].Address))
})

t.Run("filter by balance <= 100", func(t *testing.T) {
Expand All @@ -301,9 +327,10 @@ func TestGetAccounts(t *testing.T) {
})
assert.Equal(t, http.StatusOK, rsp.Result().StatusCode)
cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body)
assert.Len(t, cursor.Data, 2)
assert.Len(t, cursor.Data, 3)
assert.Equal(t, "world", string(cursor.Data[0].Address))
assert.Equal(t, "bob", string(cursor.Data[1].Address))
assert.Equal(t, "fred", string(cursor.Data[1].Address))
assert.Equal(t, "bob", string(cursor.Data[2].Address))
})

t.Run("filter by balance = 100", func(t *testing.T) {
Expand All @@ -325,9 +352,10 @@ func TestGetAccounts(t *testing.T) {
})
assert.Equal(t, http.StatusOK, rsp.Result().StatusCode)
cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body)
assert.Len(t, cursor.Data, 2)
assert.Len(t, cursor.Data, 3)
assert.Equal(t, "world", string(cursor.Data[0].Address))
assert.Equal(t, "alice", string(cursor.Data[1].Address))
assert.Equal(t, "fred", string(cursor.Data[1].Address))
assert.Equal(t, "alice", string(cursor.Data[2].Address))
})

t.Run("invalid balance", func(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions pkg/api/controllers/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ paths:
type: integer
format: int64
example: 2400
- name: balanceAsset
in: query
description: Filter accounts by their balance asset
schema:
type: string
- name: balanceOperator
x-speakeasy-ignore: true
in: query
Expand Down
7 changes: 7 additions & 0 deletions pkg/ledger/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ type AccountsQueryFilters struct {
Balance string
BalanceOperator BalanceOperator
Metadata map[string]string
BalanceAsset string
}

type BalanceOperator string
Expand Down Expand Up @@ -220,6 +221,12 @@ func (a *AccountsQuery) WithMetadataFilter(metadata map[string]string) *Accounts
return a
}

func (a *AccountsQuery) WithBalanceAssetFilter(value string) *AccountsQuery {
a.Filters.BalanceAsset = value

return a
}

type BalancesQuery struct {
PageSize uint
Offset uint
Expand Down
5 changes: 5 additions & 0 deletions pkg/storage/sqlstorage/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu
sb := sqlbuilder.NewSelectBuilder()
t := AccPaginationToken{}
sb.From(s.schema.Table("accounts"))
sb.Distinct()

var (
address = p.Filters.Address
metadata = p.Filters.Metadata
balance = p.Filters.Balance
balanceAsset = p.Filters.BalanceAsset
balanceOperator = p.Filters.BalanceOperator
)

Expand Down Expand Up @@ -78,6 +80,9 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu

if balance != "" {
sb.Join(s.schema.Table("volumes"), "accounts.address = volumes.account")
if balanceAsset != "" {
sb = sb.Where(sb.E("volumes.asset", balanceAsset))
}
balanceOperation := "volumes.input - volumes.output"

balanceValue, err := strconv.ParseInt(balance, 10, 0)
Expand Down
97 changes: 90 additions & 7 deletions pkg/storage/sqlstorage/accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,72 @@ package sqlstorage_test

import (
"context"
"testing"

"github.com/numary/ledger/pkg/core"
"github.com/numary/ledger/pkg/ledger"
"github.com/numary/ledger/pkg/storage/sqlstorage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)

func testAccounts(t *testing.T, store *sqlstorage.Store) {

err := store.Commit(context.Background(),
core.ExpandedTransaction{
Transaction: core.Transaction{
TransactionData: core.TransactionData{
Postings: []core.Posting{
{
Source: "world",
Destination: "us_bank",
Amount: core.NewMonetaryInt(100),
Asset: "USD/2",
},
{
Source: "world",
Destination: "eu_bank",
Amount: core.NewMonetaryInt(100),
Asset: "EUR/2",
},
},
},
},
PreCommitVolumes: map[string]core.AssetsVolumes{
"world": map[string]core.Volumes{
"USD/2": {},
"EUR/2": {},
},
"us_bank": map[string]core.Volumes{
"USD/2": {},
},
"eu_bank": map[string]core.Volumes{
"EUR/2": {},
},
},
PostCommitVolumes: map[string]core.AssetsVolumes{
"world": map[string]core.Volumes{
"USD/2": {
Output: core.NewMonetaryInt(100),
},
"EUR/2": {
Output: core.NewMonetaryInt(100),
},
},
"us_bank": map[string]core.Volumes{
"USD/2": {
Input: core.NewMonetaryInt(100),
},
},
"eu_bank": map[string]core.Volumes{
"EUR/2": {
Input: core.NewMonetaryInt(100),
},
},
},
},
)
require.NoError(t, err)

t.Run("success balance", func(t *testing.T) {
q := ledger.AccountsQuery{
PageSize: 10,
Expand All @@ -22,6 +79,35 @@ func testAccounts(t *testing.T, store *sqlstorage.Store) {
_, err := store.GetAccounts(context.Background(), q)
assert.NoError(t, err, "balance filter should not fail")
})
t.Run("filter balance when multiple assets match", func(t *testing.T) {
q := ledger.AccountsQuery{
PageSize: 10,
Filters: ledger.AccountsQueryFilters{
Balance: "0",
BalanceOperator: "lt",
},
}

accounts, err := store.GetAccounts(context.Background(), q)
require.NoError(t, err, "balance filter should not fail")
require.Len(t, accounts.Data, 1)
require.EqualValues(t, "world", accounts.Data[0].Address)
})
t.Run("filter balance when specifying asset", func(t *testing.T) {
q := ledger.AccountsQuery{
PageSize: 10,
Filters: ledger.AccountsQueryFilters{
Balance: "0",
BalanceOperator: "gt",
BalanceAsset: "USD/2",
},
}

accounts, err := store.GetAccounts(context.Background(), q)
require.NoError(t, err, "balance filter should not fail")
require.Len(t, accounts.Data, 1)
require.EqualValues(t, "us_bank", accounts.Data[0].Address)
})

t.Run("panic invalid balance", func(t *testing.T) {
q := ledger.AccountsQuery{
Expand Down Expand Up @@ -67,19 +153,16 @@ func testAccounts(t *testing.T, store *sqlstorage.Store) {
})

t.Run("success get accounts with address filters", func(t *testing.T) {
err := store.Commit(context.Background(), tx1, tx2, tx3, tx4)
assert.NoError(t, err)

q := ledger.AccountsQuery{
PageSize: 10,
Filters: ledger.AccountsQueryFilters{
Address: "users:1",
Address: "us_bank",
},
}

accounts, err := store.GetAccounts(context.Background(), q)
assert.NoError(t, err, "balance operator filter should not fail")
assert.Equal(t, len(accounts.Data), 1)
assert.Equal(t, accounts.Data[0].Address, core.AccountAddress("users:1"))
assert.Equal(t, accounts.Data[0].Address, core.AccountAddress("us_bank"))
})
}

0 comments on commit ad9eba2

Please sign in to comment.