diff --git a/pkg/api/controllers/balance_controller.go b/pkg/api/controllers/balance_controller.go index 250da6dee..76a7be329 100644 --- a/pkg/api/controllers/balance_controller.go +++ b/pkg/api/controllers/balance_controller.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/numary/ledger/pkg/api/apierrors" @@ -22,7 +23,7 @@ func NewBalanceController() BalanceController { func (ctl *BalanceController) GetBalancesAggregated(c *gin.Context) { l, _ := c.Get("ledger") - balancesQuery := ledger.NewBalancesQuery(). + balancesQuery := ledger.NewAggregatedBalancesQuery(). WithAddressFilter(c.Query("address")) balances, err := l.(*ledger.Ledger).GetBalancesAggregated( c.Request.Context(), *balancesQuery) @@ -66,7 +67,7 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) { balancesQuery = balancesQuery. WithOffset(token.Offset). WithAfterAddress(token.AfterAddress). - WithAddressFilter(token.AddressRegexpFilter). + WithAddressFilter(token.AddressRegexpFilter...). WithPageSize(token.PageSize) } else if c.Query(QueryKeyCursorDeprecated) != "" { @@ -96,7 +97,7 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) { balancesQuery = balancesQuery. WithOffset(token.Offset). WithAfterAddress(token.AfterAddress). - WithAddressFilter(token.AddressRegexpFilter). + WithAddressFilter(token.AddressRegexpFilter...). WithPageSize(token.PageSize) } else { @@ -106,9 +107,15 @@ func (ctl *BalanceController) GetBalances(c *gin.Context) { return } + addresses := c.QueryArray("address") + allAddresses := make([]string, 0) + for _, address := range addresses { + allAddresses = append(allAddresses, strings.Split(address, ",")...) + } + balancesQuery = balancesQuery. WithAfterAddress(c.Query("after")). - WithAddressFilter(c.Query("address")). + WithAddressFilter(allAddresses...). WithPageSize(pageSize) } diff --git a/pkg/api/controllers/pagination_test.go b/pkg/api/controllers/pagination_test.go index 6044f1053..778e654a6 100644 --- a/pkg/api/controllers/pagination_test.go +++ b/pkg/api/controllers/pagination_test.go @@ -554,7 +554,7 @@ func TestCursor(t *testing.T) { res, err := base64.RawURLEncoding.DecodeString(cursor.Next) require.NoError(t, err) require.Equal(t, - `{"pageSize":3,"offset":3,"after":"accounts:15","address":"accounts:.*"}`, + `{"pageSize":3,"offset":3,"after":"accounts:15","address":["accounts:.*"]}`, string(res)) httpResponse = internal.GetBalances(api, url.Values{ @@ -566,12 +566,12 @@ func TestCursor(t *testing.T) { res, err = base64.RawURLEncoding.DecodeString(cursor.Previous) require.NoError(t, err) require.Equal(t, - `{"pageSize":3,"offset":0,"after":"accounts:15","address":"accounts:.*"}`, + `{"pageSize":3,"offset":0,"after":"accounts:15","address":["accounts:.*"]}`, string(res)) res, err = base64.RawURLEncoding.DecodeString(cursor.Next) require.NoError(t, err) require.Equal(t, - `{"pageSize":3,"offset":6,"after":"accounts:15","address":"accounts:.*"}`, + `{"pageSize":3,"offset":6,"after":"accounts:15","address":["accounts:.*"]}`, string(res)) }) diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index 3e1298c80..832248b30 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -157,7 +157,7 @@ func (l *Ledger) GetBalances(ctx context.Context, q BalancesQuery) (api.Cursor[c return l.store.GetBalances(ctx, q) } -func (l *Ledger) GetBalancesAggregated(ctx context.Context, q BalancesQuery) (core.AssetsBalances, error) { +func (l *Ledger) GetBalancesAggregated(ctx context.Context, q AggregatedBalancesQuery) (core.AssetsBalances, error) { return l.store.GetBalancesAggregated(ctx, q) } diff --git a/pkg/ledger/storage.go b/pkg/ledger/storage.go index a5d610f07..28674d602 100644 --- a/pkg/ledger/storage.go +++ b/pkg/ledger/storage.go @@ -19,7 +19,7 @@ type Store interface { CountAccounts(context.Context, AccountsQuery) (uint64, error) GetAccounts(context.Context, AccountsQuery) (api.Cursor[core.Account], error) GetBalances(context.Context, BalancesQuery) (api.Cursor[core.AccountsBalances], error) - GetBalancesAggregated(context.Context, BalancesQuery) (core.AssetsBalances, error) + GetBalancesAggregated(context.Context, AggregatedBalancesQuery) (core.AssetsBalances, error) GetLastLog(context.Context) (*core.Log, error) GetLogs(context.Context, *LogsQuery) (api.Cursor[core.Log], error) LoadMapping(context.Context) (*core.Mapping, error) @@ -228,7 +228,7 @@ type BalancesQuery struct { } type BalancesQueryFilters struct { - AddressRegexp string + AddressRegexp []string } func NewBalancesQuery() *BalancesQuery { @@ -249,7 +249,7 @@ func (b *BalancesQuery) WithOffset(offset uint) *BalancesQuery { return b } -func (b *BalancesQuery) WithAddressFilter(address string) *BalancesQuery { +func (b *BalancesQuery) WithAddressFilter(address ...string) *BalancesQuery { b.Filters.AddressRegexp = address return b @@ -260,6 +260,46 @@ func (b *BalancesQuery) WithPageSize(pageSize uint) *BalancesQuery { return b } +type AggregatedBalancesQuery struct { + PageSize uint + Offset uint + AfterAddress string + Filters AggregatedBalancesQueryFilters +} + +type AggregatedBalancesQueryFilters struct { + AddressRegexp string +} + +func NewAggregatedBalancesQuery() *AggregatedBalancesQuery { + return &AggregatedBalancesQuery{ + PageSize: QueryDefaultPageSize, + } +} + +func (b *AggregatedBalancesQuery) WithAfterAddress(after string) *AggregatedBalancesQuery { + b.AfterAddress = after + + return b +} + +func (b *AggregatedBalancesQuery) WithOffset(offset uint) *AggregatedBalancesQuery { + b.Offset = offset + + return b +} + +func (b *AggregatedBalancesQuery) WithAddressFilter(address string) *AggregatedBalancesQuery { + b.Filters.AddressRegexp = address + + return b +} + +func (b *AggregatedBalancesQuery) WithPageSize(pageSize uint) *AggregatedBalancesQuery { + b.PageSize = pageSize + return b +} + type LogsQuery struct { AfterID uint64 PageSize uint diff --git a/pkg/storage/sqlstorage/balances.go b/pkg/storage/sqlstorage/balances.go index d480abb59..29a154b84 100644 --- a/pkg/storage/sqlstorage/balances.go +++ b/pkg/storage/sqlstorage/balances.go @@ -15,13 +15,13 @@ import ( "github.com/numary/ledger/pkg/ledger" ) -func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.BalancesQuery) (core.AssetsBalances, error) { +func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.AggregatedBalancesQuery) (core.AssetsBalances, error) { sb := sqlbuilder.NewSelectBuilder() sb.Select("asset", "sum(input - output)") sb.From(s.schema.Table("volumes")) sb.GroupBy("asset") - if q.Filters.AddressRegexp != "" { + if len(q.Filters.AddressRegexp) > 0 { switch s.Schema().Flavor() { case sqlbuilder.PostgreSQL: src := strings.Split(q.Filters.AddressRegexp, ":") @@ -116,27 +116,43 @@ func (s *Store) GetBalances(ctx context.Context, q ledger.BalancesQuery) (api.Cu t.AfterAddress = q.AfterAddress } - if q.Filters.AddressRegexp != "" { + if len(q.Filters.AddressRegexp) > 0 { switch s.Schema().Flavor() { case sqlbuilder.PostgreSQL: - src := strings.Split(q.Filters.AddressRegexp, ":") - if q.Filters.AddressRegexp[len(q.Filters.AddressRegexp)-2:] != ".*" { - sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) - } else { - src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2] - } - - for i, segment := range src { - if segment == ".*" || segment == "*" || segment == "" { - continue + if len(q.Filters.AddressRegexp) == 1 { + src := strings.Split(q.Filters.AddressRegexp[0], ":") + if q.Filters.AddressRegexp[0][len(q.Filters.AddressRegexp[0])-2:] != ".*" { + sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) + } else { + src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2] } - arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\")) - sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg)) + for i, segment := range src { + if segment == ".*" || segment == "*" || segment == "" { + continue + } + + arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\")) + sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg)) + } + } else { + asAnys := make([]any, 0) + for _, v := range q.Filters.AddressRegexp { + asAnys = append(asAnys, v) + } + sb = sb.Where(sb.In("account", asAnys...)) } case sqlbuilder.SQLite: - arg := sb.Args.Add("^" + q.Filters.AddressRegexp + "$") - sb.Where("account REGEXP " + arg) + if len(q.Filters.AddressRegexp) == 1 { + arg := sb.Args.Add("^" + q.Filters.AddressRegexp[0] + "$") + sb.Where("account REGEXP " + arg) + } else { + asAnys := make([]any, 0) + for _, v := range q.Filters.AddressRegexp { + asAnys = append(asAnys, v) + } + sb = sb.Where(sb.In("account", asAnys...)) + } } t.AddressRegexpFilter = q.Filters.AddressRegexp } diff --git a/pkg/storage/sqlstorage/balances_test.go b/pkg/storage/sqlstorage/balances_test.go index 187736b74..f70cae53d 100644 --- a/pkg/storage/sqlstorage/balances_test.go +++ b/pkg/storage/sqlstorage/balances_test.go @@ -45,6 +45,33 @@ func testGetBalances(t *testing.T, store *sqlstorage.Store) { }, cursor.Data) }) + t.Run("on 2 accounts", func(t *testing.T) { + cursor, err := store.GetBalances(context.Background(), + ledger.BalancesQuery{ + Filters: ledger.BalancesQueryFilters{ + AddressRegexp: []string{"central_bank", "users:1"}, + }, + PageSize: 10, + }) + assert.NoError(t, err) + assert.Equal(t, 10, cursor.PageSize) + assert.Equal(t, false, cursor.HasMore) + assert.Equal(t, "", cursor.Previous) + assert.Equal(t, "", cursor.Next) + assert.Equal(t, []core.AccountsBalances{ + { + "users:1": core.AssetsBalances{ + "USD": core.NewMonetaryInt(1), + }, + }, + { + "central_bank": core.AssetsBalances{ + "USD": core.NewMonetaryInt(199), + }, + }, + }, cursor.Data) + }) + t.Run("limit", func(t *testing.T) { cursor, err := store.GetBalances(context.Background(), ledger.BalancesQuery{ @@ -114,7 +141,7 @@ func testGetBalances(t *testing.T, store *sqlstorage.Store) { ledger.BalancesQuery{ PageSize: 10, AfterAddress: "world", - Filters: ledger.BalancesQueryFilters{AddressRegexp: "users:.+"}, + Filters: ledger.BalancesQueryFilters{AddressRegexp: []string{"users:.+"}}, }) assert.NoError(t, err) assert.Equal(t, 10, cursor.PageSize) @@ -135,7 +162,7 @@ func testGetBalancesAggregated(t *testing.T, store *sqlstorage.Store) { err := store.Commit(context.Background(), tx1, tx2, tx3) assert.NoError(t, err) - q := ledger.BalancesQuery{ + q := ledger.AggregatedBalancesQuery{ PageSize: 10, } cursor, err := store.GetBalancesAggregated(context.Background(), q) diff --git a/pkg/storage/sqlstorage/pagination.go b/pkg/storage/sqlstorage/pagination.go index d542741e8..be9bd0255 100644 --- a/pkg/storage/sqlstorage/pagination.go +++ b/pkg/storage/sqlstorage/pagination.go @@ -32,7 +32,7 @@ type BalancesPaginationToken struct { PageSize uint `json:"pageSize"` Offset uint `json:"offset"` AfterAddress string `json:"after,omitempty"` - AddressRegexpFilter string `json:"address,omitempty"` + AddressRegexpFilter[] string `json:"address,omitempty"` } type LogsPaginationToken struct {