Skip to content

Commit

Permalink
feat: make /balances accept multiple addresses (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag authored Jul 27, 2023
1 parent 529ca39 commit b3773ef
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 31 deletions.
15 changes: 11 additions & 4 deletions pkg/api/controllers/balance_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/numary/ledger/pkg/api/apierrors"
Expand All @@ -22,7 +23,7 @@ func NewBalanceController() BalanceController {
func (ctl *BalanceController) GetBalancesAggregated(c *gin.Context) {
l, _ := c.Get("ledger")

balancesQuery := ledger.NewBalancesQuery().
balancesQuery := ledger.NewAggregatedBalancesQuery().
WithAddressFilter(c.Query("address"))
balances, err := l.(*ledger.Ledger).GetBalancesAggregated(
c.Request.Context(), *balancesQuery)
Expand Down Expand Up @@ -66,7 +67,7 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) {
balancesQuery = balancesQuery.
WithOffset(token.Offset).
WithAfterAddress(token.AfterAddress).
WithAddressFilter(token.AddressRegexpFilter).
WithAddressFilter(token.AddressRegexpFilter...).
WithPageSize(token.PageSize)

} else if c.Query(QueryKeyCursorDeprecated) != "" {
Expand Down Expand Up @@ -96,7 +97,7 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) {
balancesQuery = balancesQuery.
WithOffset(token.Offset).
WithAfterAddress(token.AfterAddress).
WithAddressFilter(token.AddressRegexpFilter).
WithAddressFilter(token.AddressRegexpFilter...).
WithPageSize(token.PageSize)

} else {
Expand All @@ -106,9 +107,15 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) {
return
}

addresses := c.QueryArray("address")
allAddresses := make([]string, 0)
for _, address := range addresses {
allAddresses = append(allAddresses, strings.Split(address, ",")...)
}

balancesQuery = balancesQuery.
WithAfterAddress(c.Query("after")).
WithAddressFilter(c.Query("address")).
WithAddressFilter(allAddresses...).
WithPageSize(pageSize)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/api/controllers/pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func TestCursor(t *testing.T) {
res, err := base64.RawURLEncoding.DecodeString(cursor.Next)
require.NoError(t, err)
require.Equal(t,
`{"pageSize":3,"offset":3,"after":"accounts:15","address":"accounts:.*"}`,
`{"pageSize":3,"offset":3,"after":"accounts:15","address":["accounts:.*"]}`,
string(res))

httpResponse = internal.GetBalances(api, url.Values{
Expand All @@ -566,12 +566,12 @@ func TestCursor(t *testing.T) {
res, err = base64.RawURLEncoding.DecodeString(cursor.Previous)
require.NoError(t, err)
require.Equal(t,
`{"pageSize":3,"offset":0,"after":"accounts:15","address":"accounts:.*"}`,
`{"pageSize":3,"offset":0,"after":"accounts:15","address":["accounts:.*"]}`,
string(res))
res, err = base64.RawURLEncoding.DecodeString(cursor.Next)
require.NoError(t, err)
require.Equal(t,
`{"pageSize":3,"offset":6,"after":"accounts:15","address":"accounts:.*"}`,
`{"pageSize":3,"offset":6,"after":"accounts:15","address":["accounts:.*"]}`,
string(res))
})

Expand Down
2 changes: 1 addition & 1 deletion pkg/ledger/ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (l *Ledger) GetBalances(ctx context.Context, q BalancesQuery) (api.Cursor[c
return l.store.GetBalances(ctx, q)
}

func (l *Ledger) GetBalancesAggregated(ctx context.Context, q BalancesQuery) (core.AssetsBalances, error) {
func (l *Ledger) GetBalancesAggregated(ctx context.Context, q AggregatedBalancesQuery) (core.AssetsBalances, error) {
return l.store.GetBalancesAggregated(ctx, q)
}

Expand Down
46 changes: 43 additions & 3 deletions pkg/ledger/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Store interface {
CountAccounts(context.Context, AccountsQuery) (uint64, error)
GetAccounts(context.Context, AccountsQuery) (api.Cursor[core.Account], error)
GetBalances(context.Context, BalancesQuery) (api.Cursor[core.AccountsBalances], error)
GetBalancesAggregated(context.Context, BalancesQuery) (core.AssetsBalances, error)
GetBalancesAggregated(context.Context, AggregatedBalancesQuery) (core.AssetsBalances, error)
GetLastLog(context.Context) (*core.Log, error)
GetLogs(context.Context, *LogsQuery) (api.Cursor[core.Log], error)
LoadMapping(context.Context) (*core.Mapping, error)
Expand Down Expand Up @@ -228,7 +228,7 @@ type BalancesQuery struct {
}

type BalancesQueryFilters struct {
AddressRegexp string
AddressRegexp []string
}

func NewBalancesQuery() *BalancesQuery {
Expand All @@ -249,7 +249,7 @@ func (b *BalancesQuery) WithOffset(offset uint) *BalancesQuery {
return b
}

func (b *BalancesQuery) WithAddressFilter(address string) *BalancesQuery {
func (b *BalancesQuery) WithAddressFilter(address ...string) *BalancesQuery {
b.Filters.AddressRegexp = address

return b
Expand All @@ -260,6 +260,46 @@ func (b *BalancesQuery) WithPageSize(pageSize uint) *BalancesQuery {
return b
}

type AggregatedBalancesQuery struct {
PageSize uint
Offset uint
AfterAddress string
Filters AggregatedBalancesQueryFilters
}

type AggregatedBalancesQueryFilters struct {
AddressRegexp string
}

func NewAggregatedBalancesQuery() *AggregatedBalancesQuery {
return &AggregatedBalancesQuery{
PageSize: QueryDefaultPageSize,
}
}

func (b *AggregatedBalancesQuery) WithAfterAddress(after string) *AggregatedBalancesQuery {
b.AfterAddress = after

return b
}

func (b *AggregatedBalancesQuery) WithOffset(offset uint) *AggregatedBalancesQuery {
b.Offset = offset

return b
}

func (b *AggregatedBalancesQuery) WithAddressFilter(address string) *AggregatedBalancesQuery {
b.Filters.AddressRegexp = address

return b
}

func (b *AggregatedBalancesQuery) WithPageSize(pageSize uint) *AggregatedBalancesQuery {
b.PageSize = pageSize
return b
}

type LogsQuery struct {
AfterID uint64
PageSize uint
Expand Down
50 changes: 33 additions & 17 deletions pkg/storage/sqlstorage/balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (
"github.com/numary/ledger/pkg/ledger"
)

func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.BalancesQuery) (core.AssetsBalances, error) {
func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.AggregatedBalancesQuery) (core.AssetsBalances, error) {
sb := sqlbuilder.NewSelectBuilder()
sb.Select("asset", "sum(input - output)")
sb.From(s.schema.Table("volumes"))
sb.GroupBy("asset")

if q.Filters.AddressRegexp != "" {
if len(q.Filters.AddressRegexp) > 0 {
switch s.Schema().Flavor() {
case sqlbuilder.PostgreSQL:
src := strings.Split(q.Filters.AddressRegexp, ":")
Expand Down Expand Up @@ -116,27 +116,43 @@ func (s *Store) GetBalances(ctx context.Context, q ledger.BalancesQuery) (api.Cu
t.AfterAddress = q.AfterAddress
}

if q.Filters.AddressRegexp != "" {
if len(q.Filters.AddressRegexp) > 0 {
switch s.Schema().Flavor() {
case sqlbuilder.PostgreSQL:
src := strings.Split(q.Filters.AddressRegexp, ":")
if q.Filters.AddressRegexp[len(q.Filters.AddressRegexp)-2:] != ".*" {
sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src)))
} else {
src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2]
}

for i, segment := range src {
if segment == ".*" || segment == "*" || segment == "" {
continue
if len(q.Filters.AddressRegexp) == 1 {
src := strings.Split(q.Filters.AddressRegexp[0], ":")
if q.Filters.AddressRegexp[0][len(q.Filters.AddressRegexp[0])-2:] != ".*" {
sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src)))
} else {
src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2]
}

arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\"))
sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg))
for i, segment := range src {
if segment == ".*" || segment == "*" || segment == "" {
continue
}

arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\"))
sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg))
}
} else {
asAnys := make([]any, 0)
for _, v := range q.Filters.AddressRegexp {
asAnys = append(asAnys, v)
}
sb = sb.Where(sb.In("account", asAnys...))
}
case sqlbuilder.SQLite:
arg := sb.Args.Add("^" + q.Filters.AddressRegexp + "$")
sb.Where("account REGEXP " + arg)
if len(q.Filters.AddressRegexp) == 1 {
arg := sb.Args.Add("^" + q.Filters.AddressRegexp[0] + "$")
sb.Where("account REGEXP " + arg)
} else {
asAnys := make([]any, 0)
for _, v := range q.Filters.AddressRegexp {
asAnys = append(asAnys, v)
}
sb = sb.Where(sb.In("account", asAnys...))
}
}
t.AddressRegexpFilter = q.Filters.AddressRegexp
}
Expand Down
31 changes: 29 additions & 2 deletions pkg/storage/sqlstorage/balances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,33 @@ func testGetBalances(t *testing.T, store *sqlstorage.Store) {
}, cursor.Data)
})

t.Run("on 2 accounts", func(t *testing.T) {
cursor, err := store.GetBalances(context.Background(),
ledger.BalancesQuery{
Filters: ledger.BalancesQueryFilters{
AddressRegexp: []string{"central_bank", "users:1"},
},
PageSize: 10,
})
assert.NoError(t, err)
assert.Equal(t, 10, cursor.PageSize)
assert.Equal(t, false, cursor.HasMore)
assert.Equal(t, "", cursor.Previous)
assert.Equal(t, "", cursor.Next)
assert.Equal(t, []core.AccountsBalances{
{
"users:1": core.AssetsBalances{
"USD": core.NewMonetaryInt(1),
},
},
{
"central_bank": core.AssetsBalances{
"USD": core.NewMonetaryInt(199),
},
},
}, cursor.Data)
})

t.Run("limit", func(t *testing.T) {
cursor, err := store.GetBalances(context.Background(),
ledger.BalancesQuery{
Expand Down Expand Up @@ -114,7 +141,7 @@ func testGetBalances(t *testing.T, store *sqlstorage.Store) {
ledger.BalancesQuery{
PageSize: 10,
AfterAddress: "world",
Filters: ledger.BalancesQueryFilters{AddressRegexp: "users:.+"},
Filters: ledger.BalancesQueryFilters{AddressRegexp: []string{"users:.+"}},
})
assert.NoError(t, err)
assert.Equal(t, 10, cursor.PageSize)
Expand All @@ -135,7 +162,7 @@ func testGetBalancesAggregated(t *testing.T, store *sqlstorage.Store) {
err := store.Commit(context.Background(), tx1, tx2, tx3)
assert.NoError(t, err)

q := ledger.BalancesQuery{
q := ledger.AggregatedBalancesQuery{
PageSize: 10,
}
cursor, err := store.GetBalancesAggregated(context.Background(), q)
Expand Down
2 changes: 1 addition & 1 deletion pkg/storage/sqlstorage/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type BalancesPaginationToken struct {
PageSize uint `json:"pageSize"`
Offset uint `json:"offset"`
AfterAddress string `json:"after,omitempty"`
AddressRegexpFilter string `json:"address,omitempty"`
AddressRegexpFilter[] string `json:"address,omitempty"`
}

type LogsPaginationToken struct {
Expand Down

0 comments on commit b3773ef

Please sign in to comment.