diff --git a/cmd/container.go b/cmd/container.go index 562e05f66..a8ce5c448 100644 --- a/cmd/container.go +++ b/cmd/container.go @@ -24,11 +24,13 @@ import ( "github.com/numary/ledger/pkg/api/middlewares" "github.com/numary/ledger/pkg/api/routes" "github.com/numary/ledger/pkg/bus" + "github.com/numary/ledger/pkg/contextlogger" "github.com/numary/ledger/pkg/ledger" "github.com/numary/ledger/pkg/redis" "github.com/numary/ledger/pkg/storage/sqlstorage" "github.com/sirupsen/logrus" "github.com/spf13/viper" + "github.com/uptrace/opentelemetry-go-extra/otellogrus" "github.com/xdg-go/scram" "go.opentelemetry.io/otel/trace" "go.uber.org/fx" @@ -43,12 +45,26 @@ func NewContainer(v *viper.Viper, userOptions ...fx.Option) *fx.App { options = append(options, fx.NopLogger) } + debug := viper.GetBool(debugFlag) + l := logrus.New() - if v.GetBool(debugFlag) { + if debug { l.Level = logrus.DebugLevel } - loggerFactory := logging.StaticLoggerFactory(logginglogrus.New(l)) - logging.SetFactory(loggerFactory) + if viper.GetBool(otlptraces.OtelTracesFlag) { + l.AddHook(otellogrus.NewHook(otellogrus.WithLevels( + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + logrus.WarnLevel, + ))) + } + logging.SetFactory(contextlogger.NewFactory( + logging.StaticLoggerFactory(logginglogrus.New(l)), + )) + if debug { + sqlstorage.InstrumentalizeSQLDrivers() + } topics := v.GetStringSlice(publisherTopicMappingFlag) mapping := make(map[string]string) @@ -169,7 +185,7 @@ func NewContainer(v *viper.Viper, userOptions ...fx.Option) *fx.App { // Handle resolver options = append(options, ledger.ResolveModule( - v.GetInt64(numscriptCacheCapacity))) + v.GetInt64(cacheCapacityBytes), v.GetInt64(cacheMaxNumKeys))) // Api middlewares options = append(options, routes.ProvidePerLedgerMiddleware(func(tp trace.TracerProvider) []gin.HandlerFunc { diff --git a/cmd/container_test.go b/cmd/container_test.go index c19806b6d..c62a75a40 100644 --- a/cmd/container_test.go +++ b/cmd/container_test.go @@ -250,7 +250,8 @@ func TestContainers(t *testing.T) { // Default options v.Set(storageDriverFlag, sqlstorage.SQLite.String()) v.Set(storageDirFlag, "/tmp") - v.Set(numscriptCacheCapacity, 100) + v.Set(cacheCapacityBytes, 100000000) + v.Set(cacheMaxNumKeys, 100) //v.Set(storageSQLiteDBNameFlag, uuid.New()) tc.init(v) app := NewContainer(v, options...) diff --git a/cmd/root.go b/cmd/root.go index a000595c0..35a699d11 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -50,7 +50,8 @@ const ( commitPolicyFlag = "commit-policy" - numscriptCacheCapacity = "numscript-cache-capacity" + cacheCapacityBytes = "cache-capacity-bytes" + cacheMaxNumKeys = "cache-max-num-keys" ) var ( @@ -140,7 +141,9 @@ func NewRootCommand() *cobra.Command { root.PersistentFlags().Bool(authBearerUseScopesFlag, false, "Use scopes as defined by rfc https://datatracker.ietf.org/doc/html/rfc8693") root.PersistentFlags().String(commitPolicyFlag, "", "Transaction commit policy (default or allow-past-timestamps)") - root.PersistentFlags().Int(numscriptCacheCapacity, 100, "Capacity of the cache storing Numscript in RAM") + // 100 000 000 bytes is 100 MB + root.PersistentFlags().Int(cacheCapacityBytes, 100000000, "Capacity in bytes of the cache storing Numscript in RAM") + root.PersistentFlags().Int(cacheMaxNumKeys, 100, "Maximum number of Numscript to be stored in the cache in RAM") otlptraces.InitOTLPTracesFlags(root.PersistentFlags()) internal.InitHTTPBasicFlags(root) diff --git a/cmd/server_start.go b/cmd/server_start.go index b804ca7c4..ac9d624d8 100644 --- a/cmd/server_start.go +++ b/cmd/server_start.go @@ -6,13 +6,9 @@ import ( "net/http" "github.com/formancehq/go-libs/logging" - "github.com/formancehq/go-libs/logging/logginglogrus" - "github.com/formancehq/go-libs/otlp/otlptraces" "github.com/numary/ledger/pkg/api" - "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/uptrace/opentelemetry-go-extra/otellogrus" "go.uber.org/fx" ) @@ -20,21 +16,6 @@ func NewServerStart() *cobra.Command { return &cobra.Command{ Use: "start", RunE: func(cmd *cobra.Command, args []string) error { - l := logrus.New() - if viper.GetBool(debugFlag) { - l.Level = logrus.DebugLevel - } - if viper.GetBool(otlptraces.OtelTracesFlag) { - l.AddHook(otellogrus.NewHook(otellogrus.WithLevels( - logrus.PanicLevel, - logrus.FatalLevel, - logrus.ErrorLevel, - logrus.WarnLevel, - ))) - } - loggerFactory := logging.StaticLoggerFactory(logginglogrus.New(l)) - logging.SetFactory(loggerFactory) - app := NewContainer( viper.GetViper(), fx.Invoke(func(lc fx.Lifecycle, h *api.API) { diff --git a/go.mod b/go.mod index 1a126455a..bfe7c7a8e 100755 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/numary/ledger go 1.18 require ( + github.com/DmitriyVTitov/size v1.5.0 github.com/Masterminds/semver/v3 v3.2.0 github.com/Shopify/sarama v1.37.2 github.com/ThreeDotsLabs/watermill v1.1.1 diff --git a/go.sum b/go.sum index 52970e308..50c31866f 100755 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g= +github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= @@ -216,6 +218,8 @@ github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0L github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= diff --git a/pkg/api/controllers/context_test.go b/pkg/api/controllers/context_test.go deleted file mode 100644 index 838dbb1ae..000000000 --- a/pkg/api/controllers/context_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package controllers_test - -import ( - "context" - "net/url" - "testing" - - "github.com/google/uuid" - "github.com/numary/ledger/pkg" - "github.com/numary/ledger/pkg/api" - "github.com/numary/ledger/pkg/api/controllers" - "github.com/numary/ledger/pkg/api/internal" - "github.com/numary/ledger/pkg/core" - "github.com/stretchr/testify/require" - "go.uber.org/fx" -) - -func TestContext(t *testing.T) { - internal.RunTest(t, fx.Invoke(func(lc fx.Lifecycle, api *api.API) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - t.Run("GET/stats", func(t *testing.T) { - rsp := internal.GetLedgerStats(api) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("GET/logs", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{}) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("GET/accounts", func(t *testing.T) { - rsp := internal.GetAccounts(api, url.Values{}) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("GET/transactions", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{}) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("POST/transactions", func(t *testing.T) { - rsp := internal.PostTransaction(t, api, controllers.PostTransaction{}, true) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("POST/transactions/batch", func(t *testing.T) { - rsp := internal.PostTransactionBatch(t, api, core.Transactions{}) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - t.Run("GET/balances", func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{}) - _, err := uuid.Parse(rsp.Header().Get(string(pkg.KeyContextID))) - require.NoError(t, err) - }) - - return nil - }, - }) - })) -} diff --git a/pkg/api/internal/testing.go b/pkg/api/internal/testing.go index 528658f09..6cfb68253 100644 --- a/pkg/api/internal/testing.go +++ b/pkg/api/internal/testing.go @@ -248,7 +248,8 @@ func RunTest(t *testing.T, options ...fx.Option) { options = append([]fx.Option{ api.Module(api.Config{StorageDriver: "sqlite", Version: "latest", UseScopes: true}), - ledger.ResolveModule(100), + // 100 000 000 bytes is 100 MB + ledger.ResolveModule(100000000, 100), ledgertesting.ProvideLedgerStorageDriver(), fx.Invoke(func(driver storage.Driver[ledger.Store], lc fx.Lifecycle) { lc.Append(fx.Hook{ diff --git a/pkg/api/middlewares/ledger_middleware.go b/pkg/api/middlewares/ledger_middleware.go index 705020b1c..c9a9a98c0 100644 --- a/pkg/api/middlewares/ledger_middleware.go +++ b/pkg/api/middlewares/ledger_middleware.go @@ -2,12 +2,9 @@ package middlewares import ( "context" - "fmt" + "net/http" - "github.com/formancehq/go-libs/logging" "github.com/gin-gonic/gin" - "github.com/google/uuid" - "github.com/numary/ledger/pkg" "github.com/numary/ledger/pkg/api/apierrors" "github.com/numary/ledger/pkg/contextlogger" "github.com/numary/ledger/pkg/ledger" @@ -28,36 +25,23 @@ func (m *LedgerMiddleware) LedgerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { name := c.Param("ledger") if name == "" { + c.AbortWithStatus(http.StatusNotFound) return } - ctx, span := opentelemetry.Start(c.Request.Context(), "Ledger access") + span := opentelemetry.WrapGinContext(c, "Ledger access") defer span.End() - contextKeyID := uuid.NewString() - id := span.SpanContext().SpanID() - if id == [8]byte{} { - logging.GetLogger(ctx).Debugf( - "ledger middleware SpanID is empty, new id generated %s", contextKeyID) - } else { - contextKeyID = fmt.Sprint(id) - } - ctx = context.WithValue(ctx, pkg.KeyContextID, contextKeyID) - c.Header(string(pkg.KeyContextID), contextKeyID) - - loggerFactory := logging.StaticLoggerFactory( - contextlogger.New(ctx, logging.GetLogger(ctx))) - logging.SetFactory(loggerFactory) + contextlogger.WrapGinRequest(c) - l, err := m.resolver.GetLedger(ctx, name) + l, err := m.resolver.GetLedger(c.Request.Context(), name) if err != nil { apierrors.ResponseError(c, err) return } - defer l.Close(ctx) - c.Set("ledger", l) + defer l.Close(context.Background()) - c.Request = c.Request.WithContext(ctx) + c.Set("ledger", l) c.Next() } } diff --git a/pkg/api/middlewares/transaction.go b/pkg/api/middlewares/transaction.go index 01a7f53d6..70e973504 100644 --- a/pkg/api/middlewares/transaction.go +++ b/pkg/api/middlewares/transaction.go @@ -52,33 +52,39 @@ func newBufferedWriter(rw gin.ResponseWriter) *bufferedResponseWriter { func Transaction(locker Locker) func(c *gin.Context) { return func(c *gin.Context) { - ctx, span := opentelemetry.Start(c.Request.Context(), "Ledger locking") + ctx, span := opentelemetry.Start(c.Request.Context(), "Wait ledger lock") defer span.End() c.Request = c.Request.WithContext(ctx) - unlock, err := locker.Lock(c.Request.Context(), c.Param("ledger")) - if err != nil { - panic(err) - } - defer unlock(context.Background()) // Use a background context instead of the request one as it could have been cancelled - bufferedWriter := newBufferedWriter(c.Writer) - c.Request = c.Request.WithContext(storage.TransactionalContext(c.Request.Context())) c.Writer = bufferedWriter - defer func() { - _ = storage.RollbackTransaction(c.Request.Context()) - }() - - c.Next() - if c.Writer.Status() >= 200 && c.Writer.Status() < 300 && - storage.IsTransactionRegistered(c.Request.Context()) { - if err := storage.CommitTransaction(c.Request.Context()); err != nil { - apierrors.ResponseError(c, err) - return + func() { + unlock, err := locker.Lock(c.Request.Context(), c.Param("ledger")) + if err != nil { + panic(err) } - } + defer unlock(context.Background()) // Use a background context instead of the request one as it could have been cancelled + + ctx, span = opentelemetry.Start(c.Request.Context(), "Ledger locked") + defer span.End() + c.Request = c.Request.WithContext(ctx) + c.Request = c.Request.WithContext(storage.TransactionalContext(c.Request.Context())) + defer func() { + _ = storage.RollbackTransaction(c.Request.Context()) + }() + + c.Next() + + if c.Writer.Status() >= 200 && c.Writer.Status() < 300 && + storage.IsTransactionRegistered(c.Request.Context()) { + if err := storage.CommitTransaction(c.Request.Context()); err != nil { + apierrors.ResponseError(c, err) + return + } + } + }() if err := bufferedWriter.WriteResponse(); err != nil { _ = c.Error(err) diff --git a/pkg/context.go b/pkg/context.go deleted file mode 100644 index 0f149fe36..000000000 --- a/pkg/context.go +++ /dev/null @@ -1,5 +0,0 @@ -package pkg - -type ContextKeyIDType string - -var KeyContextID ContextKeyIDType = "contextID" diff --git a/pkg/contextlogger/contextlogger.go b/pkg/contextlogger/contextlogger.go index c111630f6..3b2fed6e9 100644 --- a/pkg/contextlogger/contextlogger.go +++ b/pkg/contextlogger/contextlogger.go @@ -4,67 +4,48 @@ import ( "context" "github.com/formancehq/go-libs/logging" - "github.com/numary/ledger/pkg" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.opentelemetry.io/otel/trace" ) -var _ logging.Logger = &ContextLogger{} +type contextKey string -type ContextLogger struct { - ctx context.Context - underlyingLogger logging.Logger -} - -func New(ctx context.Context, logger logging.Logger) *ContextLogger { - return &ContextLogger{ - ctx: ctx, - underlyingLogger: logger, - } -} - -func (c ContextLogger) Debugf(format string, args ...any) { - id := c.ctx.Value(pkg.KeyContextID) - c.underlyingLogger. - WithFields(map[string]any{string(pkg.KeyContextID): id}). - Debugf(format, args...) -} +var loggerContextKey contextKey = "logger" -func (c ContextLogger) Infof(format string, args ...any) { - id := c.ctx.Value(pkg.KeyContextID) - c.underlyingLogger. - WithFields(map[string]any{string(pkg.KeyContextID): id}). - Infof(format, args...) +type Factory struct { + underlying logging.LoggerFactory } -func (c ContextLogger) Errorf(format string, args ...any) { - id := c.ctx.Value(pkg.KeyContextID) - c.underlyingLogger. - WithFields(map[string]any{string(pkg.KeyContextID): id}). - Errorf(format, args...) -} - -func (c ContextLogger) Debug(args ...any) { - c.underlyingLogger.Debug(args...) +func (c *Factory) Get(ctx context.Context) logging.Logger { + v := ctx.Value(loggerContextKey) + if v == nil { + return c.underlying.Get(ctx) + } + return v.(logging.Logger) } -func (c ContextLogger) Info(args ...any) { - c.underlyingLogger.Info(args...) +func NewFactory(underlyingFactory logging.LoggerFactory) *Factory { + return &Factory{ + underlying: underlyingFactory, + } } -func (c ContextLogger) Error(args ...any) { - c.underlyingLogger.Error(args...) -} +var _ logging.LoggerFactory = &Factory{} -func (c ContextLogger) WithFields(m map[string]any) logging.Logger { - m[string(pkg.KeyContextID)] = c.ctx.Value(pkg.KeyContextID) - return &ContextLogger{ - ctx: c.ctx, - underlyingLogger: c.underlyingLogger.WithFields(m), - } +func ContextWithLogger(ctx context.Context, logger logging.Logger) context.Context { + return context.WithValue(ctx, loggerContextKey, logger) } -func (c ContextLogger) WithContext(ctx context.Context) logging.Logger { - return &ContextLogger{ - ctx: ctx, - underlyingLogger: c.underlyingLogger, +func WrapGinRequest(c *gin.Context) { + span := trace.SpanFromContext(c.Request.Context()) + contextKeyID := uuid.NewString() + if span.SpanContext().SpanID().IsValid() { + contextKeyID = span.SpanContext().SpanID().String() } + c.Request = c.Request.WithContext( + ContextWithLogger(c.Request.Context(), logging.GetLogger(c.Request.Context()).WithFields(map[string]any{ + "contextID": contextKeyID, + })), + ) } diff --git a/pkg/core/numscript.go b/pkg/core/numscript.go index 530eb7b14..1a422d567 100644 --- a/pkg/core/numscript.go +++ b/pkg/core/numscript.go @@ -3,6 +3,7 @@ package core import ( "encoding/json" "fmt" + "sort" "strings" ) @@ -50,11 +51,21 @@ func TxsToScriptsData(txsData ...TransactionData) []ScriptData { } sb.WriteString("vars {\n") + accVars := make([]string, 0) for _, v := range accountsToVars { - sb.WriteString(fmt.Sprintf("\taccount $%s\n", v.name)) + accVars = append(accVars, v.name) } + sort.Strings(accVars) + for _, v := range accVars { + sb.WriteString(fmt.Sprintf("\taccount $%s\n", v)) + } + monVars := make([]string, 0) for _, v := range monetaryToVars { - sb.WriteString(fmt.Sprintf("\tmonetary $%s\n", v.name)) + monVars = append(monVars, v.name) + } + sort.Strings(monVars) + for _, v := range monVars { + sb.WriteString(fmt.Sprintf("\tmonetary $%s\n", v)) } sb.WriteString("}\n") diff --git a/pkg/ledger/cache.go b/pkg/ledger/cache.go new file mode 100644 index 000000000..41cabc0bf --- /dev/null +++ b/pkg/ledger/cache.go @@ -0,0 +1,19 @@ +package ledger + +import ( + "github.com/dgraph-io/ristretto" + "github.com/pkg/errors" +) + +func NewCache(capacityBytes, maxNumKeys int64, metrics bool) *ristretto.Cache { + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: maxNumKeys * 10, + MaxCost: capacityBytes, + BufferItems: 64, + Metrics: metrics, + }) + if err != nil { + panic(errors.Wrap(err, "creating cache")) + } + return cache +} diff --git a/pkg/ledger/executor.go b/pkg/ledger/executor.go index 150b31e69..063ee0b1e 100644 --- a/pkg/ledger/executor.go +++ b/pkg/ledger/executor.go @@ -7,16 +7,24 @@ import ( "fmt" "time" + "github.com/DmitriyVTitov/size" + "github.com/dgraph-io/ristretto" machine "github.com/formancehq/machine/core" "github.com/formancehq/machine/script/compiler" "github.com/formancehq/machine/vm" "github.com/formancehq/machine/vm/program" "github.com/numary/ledger/pkg/core" + "github.com/numary/ledger/pkg/opentelemetry" "github.com/numary/ledger/pkg/storage" "github.com/pkg/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) func (l *Ledger) Execute(ctx context.Context, checkMapping, preview bool, scripts ...core.ScriptData) ([]core.ExpandedTransaction, error) { + ctx, span := opentelemetry.Start(ctx, "Execute") + defer span.End() + if len(scripts) == 0 { return []core.ExpandedTransaction{}, NewScriptError(ScriptErrorNoScript, "no script to execute") @@ -96,26 +104,13 @@ func (l *Ledger) Execute(ctx context.Context, checkMapping, preview bool, script "no script to execute") } - h := sha256.New() - if _, err = h.Write([]byte(script.Plain)); err != nil { - return []core.ExpandedTransaction{}, errors.Wrap(err, "hashing script") - } - curr := h.Sum(nil) - - var m *vm.Machine - if cachedP, found := l.cache.Get(curr); found { - m = vm.NewMachine(cachedP.(program.Program)) - } else { - newP, err := compiler.Compile(script.Plain) - if err != nil { - return []core.ExpandedTransaction{}, NewScriptError(ScriptErrorCompilationFailed, - err.Error()) - } - l.cache.Set(curr, *newP, 1) - m = vm.NewMachine(*newP) + m, err := NewMachineFromScript(script.Plain, l.cache, span) + if err != nil { + return []core.ExpandedTransaction{}, NewScriptError(ScriptErrorCompilationFailed, + err.Error()) } - if err = m.SetVarsFromJSON(script.Vars); err != nil { + if err := m.SetVarsFromJSON(script.Vars); err != nil { return []core.ExpandedTransaction{}, NewScriptError(ScriptErrorCompilationFailed, errors.Wrap(err, "could not set variables").Error()) } @@ -357,3 +352,30 @@ func (l *Ledger) Execute(ctx context.Context, checkMapping, preview bool, script return txs, nil } + +func NewMachineFromScript(script string, cache *ristretto.Cache, span trace.Span) (*vm.Machine, error) { + h := sha256.New() + if _, err := h.Write([]byte(script)); err != nil { + return nil, errors.Wrap(err, "hashing script") + } + curr := h.Sum(nil) + + if cachedProgram, found := cache.Get(curr); found { + span.SetAttributes(attribute.Bool("numscript-cache-hit", true)) + return vm.NewMachine(cachedProgram.(program.Program)), nil + } + + span.SetAttributes(attribute.Bool("numscript-cache-hit", false)) + prog, err := compiler.Compile(script) + if err != nil { + return nil, err + } + + progSizeBytes := size.Of(*prog) + if progSizeBytes == -1 { + return nil, fmt.Errorf("error while calculating the size in bytes of the program") + } + cache.Set(curr, *prog, int64(progSizeBytes)) + + return vm.NewMachine(*prog), nil +} diff --git a/pkg/ledger/executor_test.go b/pkg/ledger/executor_test.go index 7113788e3..75cf8db8f 100644 --- a/pkg/ledger/executor_test.go +++ b/pkg/ledger/executor_test.go @@ -2,14 +2,18 @@ package ledger_test import ( "context" + "crypto/sha256" "encoding/json" "fmt" "strconv" "testing" + "github.com/DmitriyVTitov/size" + "github.com/formancehq/machine/script/compiler" "github.com/numary/ledger/pkg/api/apierrors" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" + "github.com/numary/ledger/pkg/opentelemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -750,16 +754,89 @@ func assertBalance(t *testing.T, l *ledger.Ledger, account, asset string, amount ) } +func TestNewMachineFromScript(t *testing.T) { + _, span := opentelemetry.Start(context.Background(), "TestNewMachineFromScript") + defer span.End() + + txData := core.TransactionData{} + for i := 0; i < nbPostings; i++ { + txData.Postings = append(txData.Postings, core.Posting{ + Source: "world", + Destination: "benchmarks:" + strconv.Itoa(i), + Asset: "COIN", + Amount: core.NewMonetaryInt(10), + }) + } + _, err := txData.Postings.Validate() + require.NoError(t, err) + scripts := core.TxsToScriptsData(txData) + script := scripts[0].Plain + + h := sha256.New() + _, err = h.Write([]byte(script)) + require.NoError(t, err) + key := h.Sum(nil) + keySizeBytes := size.Of(key) + require.NotEqual(t, -1, keySizeBytes) + + prog, err := compiler.Compile(script) + require.NoError(t, err) + progSizeBytes := size.Of(*prog) + require.NotEqual(t, -1, progSizeBytes) + + t.Run("exact size", func(t *testing.T) { + capacityBytes := int64(keySizeBytes + progSizeBytes) + + cache := ledger.NewCache(capacityBytes, 1, true) + + m, err := ledger.NewMachineFromScript(script, cache, span) + require.NoError(t, err) + require.NotNil(t, m) + cache.Wait() + require.Equal(t, uint64(0), cache.Metrics.Hits()) + require.Equal(t, uint64(1), cache.Metrics.Misses()) + require.Equal(t, uint64(1), cache.Metrics.KeysAdded()) + + m, err = ledger.NewMachineFromScript(script, cache, span) + require.NoError(t, err) + require.NotNil(t, m) + cache.Wait() + require.Equal(t, uint64(1), cache.Metrics.Hits()) + require.Equal(t, uint64(1), cache.Metrics.Misses()) + require.Equal(t, uint64(1), cache.Metrics.KeysAdded()) + }) + + t.Run("one byte too small", func(t *testing.T) { + capacityBytes := int64(keySizeBytes+progSizeBytes) - 1 + + cache := ledger.NewCache(capacityBytes, 1, true) + + m, err := ledger.NewMachineFromScript(script, cache, span) + require.NoError(t, err) + require.NotNil(t, m) + cache.Wait() + require.Equal(t, uint64(0), cache.Metrics.Hits()) + require.Equal(t, uint64(1), cache.Metrics.Misses()) + require.Equal(t, uint64(0), cache.Metrics.KeysAdded()) + + m, err = ledger.NewMachineFromScript(script, cache, span) + require.NoError(t, err) + require.NotNil(t, m) + cache.Wait() + require.Equal(t, uint64(0), cache.Metrics.Hits()) + require.Equal(t, uint64(2), cache.Metrics.Misses()) + require.Equal(t, uint64(0), cache.Metrics.KeysAdded()) + }) +} + var execRes []core.ExpandedTransaction -func BenchmarkLedger_PostTransactions(b *testing.B) { - runOnLedger(func(l *ledger.Ledger) { - defer func(l *ledger.Ledger, ctx context.Context) { - require.NoError(b, l.Close(ctx)) - }(l, context.Background()) +const nbPostings = 1000 +func BenchmarkLedger_PostTransactionsSingle(b *testing.B) { + runOnLedger(func(l *ledger.Ledger) { txData := core.TransactionData{} - for i := 0; i < 1000; i++ { + for i := 0; i < nbPostings; i++ { txData.Postings = append(txData.Postings, core.Posting{ Source: "world", Destination: "benchmarks:" + strconv.Itoa(i), @@ -779,12 +856,12 @@ func BenchmarkLedger_PostTransactions(b *testing.B) { res, err = l.Execute(context.Background(), true, true, script...) require.NoError(b, err) require.Len(b, res, 1) - require.Len(b, res[0].Postings, 1000) + require.Len(b, res[0].Postings, nbPostings) } execRes = res require.Len(b, execRes, 1) - require.Len(b, execRes[0].Postings, 1000) + require.Len(b, execRes[0].Postings, nbPostings) }) } @@ -907,10 +984,6 @@ func newTxsData(i int) []core.TransactionData { func BenchmarkLedger_PostTransactionsBatch(b *testing.B) { runOnLedger(func(l *ledger.Ledger) { - defer func(l *ledger.Ledger, ctx context.Context) { - require.NoError(b, l.Close(ctx)) - }(l, context.Background()) - txsData := newTxsData(1) b.ResetTimer() @@ -950,10 +1023,6 @@ func BenchmarkLedger_PostTransactionsBatch(b *testing.B) { func BenchmarkLedger_PostTransactionsBatch2(b *testing.B) { runOnLedger(func(l *ledger.Ledger) { - defer func(l *ledger.Ledger, ctx context.Context) { - require.NoError(b, l.Close(ctx)) - }(l, context.Background()) - b.ResetTimer() res := []core.ExpandedTransaction{} diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index 6653604e8..a4b80c37b 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -57,7 +57,6 @@ func (l *Ledger) Close(ctx context.Context) error { if err := l.store.Close(ctx); err != nil { return errors.Wrap(err, "closing store") } - l.cache.Close() return nil } diff --git a/pkg/ledger/ledger_test.go b/pkg/ledger/ledger_test.go index b9f71c28b..6b955ae76 100644 --- a/pkg/ledger/ledger_test.go +++ b/pkg/ledger/ledger_test.go @@ -10,14 +10,12 @@ import ( "testing" "time" - "github.com/dgraph-io/ristretto" "github.com/mitchellh/mapstructure" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" "github.com/numary/ledger/pkg/ledgertesting" "github.com/numary/ledger/pkg/storage" "github.com/pborman/uuid" - "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -71,21 +69,20 @@ func runOnLedger(f func(l *ledger.Ledger), ledgerOptions ...ledger.LedgerOption) if err != nil { return err } - cache, err := ristretto.NewCache(&ristretto.Config{ - - NumCounters: 1e7, // number of keys to track frequency of (10M). - MaxCost: 100, // maximum cost of cache. - BufferItems: 64, // number of keys per Get buffer. - }) - if err != nil { - panic(errors.Wrap(err, "creating ledger cache")) - } - l, err := ledger.NewLedger(store, ledger.NewNoOpMonitor(), cache, ledgerOptions...) + // 100 000 000 is 100MB + cache := ledger.NewCache(100000000, 100, true) + l, err := ledger.NewLedger(store, + ledger.NewNoOpMonitor(), + cache, + ledgerOptions...) if err != nil { panic(err) } lc.Append(fx.Hook{ - OnStop: l.Close, + OnStop: func(ctx context.Context) error { + cache.Close() + return l.Close(ctx) + }, }) f(l) return nil diff --git a/pkg/ledger/resolver.go b/pkg/ledger/resolver.go index 339ce91d1..b605ccb96 100644 --- a/pkg/ledger/resolver.go +++ b/pkg/ledger/resolver.go @@ -42,23 +42,14 @@ type Resolver struct { func NewResolver( storageFactory storage.Driver[Store], ledgerOptions []LedgerOption, - numscriptCacheCapacity int64, + cacheBytesCapacity, cacheMaxNumKeys int64, options ...ResolverOption, ) *Resolver { - cache, err := ristretto.NewCache(&ristretto.Config{ - NumCounters: 1e7, // number of keys to track frequency of (10M). - MaxCost: numscriptCacheCapacity, // maximum cost of cache. - BufferItems: 64, // number of keys per Get buffer. - }) - if err != nil { - panic(errors.Wrap(err, "creating ledger cache")) - } - options = append(DefaultResolverOptions, options...) r := &Resolver{ storageDriver: storageFactory, initializedStores: map[string]struct{}{}, - cache: cache, + cache: NewCache(cacheBytesCapacity, cacheMaxNumKeys, false), } for _, opt := range options { if err := opt.apply(r); err != nil { @@ -97,6 +88,10 @@ func (r *Resolver) GetLedger(ctx context.Context, name string) (*Ledger, error) return NewLedger(store, r.monitor, r.cache, r.ledgerOptions...) } +func (r *Resolver) Close() { + r.cache.Close() +} + const ResolverOptionsKey = `group:"_ledgerResolverOptions"` const ResolverLedgerOptionsKey = `name:"_ledgerResolverLedgerOptions"` @@ -106,12 +101,20 @@ func ProvideResolverOption(provider interface{}) fx.Option { ) } -func ResolveModule(numscriptCacheCapacity int64) fx.Option { +func ResolveModule(cacheBytesCapacity, cacheMaxNumKeys int64) fx.Option { return fx.Options( fx.Provide( fx.Annotate(func(storageFactory storage.Driver[Store], ledgerOptions []LedgerOption, options ...ResolverOption) *Resolver { - return NewResolver(storageFactory, ledgerOptions, numscriptCacheCapacity, options...) + return NewResolver(storageFactory, ledgerOptions, cacheBytesCapacity, cacheMaxNumKeys, options...) }, fx.ParamTags("", ResolverLedgerOptionsKey, ResolverOptionsKey)), ), + fx.Invoke(func(lc fx.Lifecycle, r *Resolver) { + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + r.Close() + return nil + }, + }) + }), ) } diff --git a/pkg/opentelemetry/tracer.go b/pkg/opentelemetry/tracer.go index 27597d259..1c179ff6f 100644 --- a/pkg/opentelemetry/tracer.go +++ b/pkg/opentelemetry/tracer.go @@ -3,6 +3,7 @@ package opentelemetry import ( "context" + "github.com/gin-gonic/gin" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" ) @@ -12,3 +13,9 @@ var Tracer = otel.Tracer("com.formance.ledger") func Start(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return Tracer.Start(ctx, name, opts...) } + +func WrapGinContext(ginContext *gin.Context, name string, opts ...trace.SpanStartOption) trace.Span { + ctx, span := Start(ginContext.Request.Context(), name, opts...) + ginContext.Request = ginContext.Request.WithContext(ctx) + return span +} diff --git a/pkg/storage/sqlstorage/driver.go b/pkg/storage/sqlstorage/driver.go index e0151fcbf..f0c11d9c7 100644 --- a/pkg/storage/sqlstorage/driver.go +++ b/pkg/storage/sqlstorage/driver.go @@ -33,37 +33,8 @@ func (d otelSQLDriverWithCheckNamedValueDisabled) CheckNamedValue(*driver.NamedV var _ = driver.NamedValueChecker(&otelSQLDriverWithCheckNamedValueDisabled{}) func UpdateSQLDriverMapping(flavor Flavor, name string) { - - // otelsql has a function Register which wrap the underlying driver, but does not mirror driver.NamedValuedChecker interface of the underlying driver - // pgx implements this interface and just return nil - // so, we need to manually wrap the driver to implements this interface and return a nil error - - db, err := sql.Open(name, "") - if err != nil { - panic(err) - } - - dri := db.Driver() - - if err = db.Close(); err != nil { - panic(err) - } - - wrappedDriver := otelsql.Wrap(dri, - otelsql.AllowRoot(), - otelsql.TraceQueryWithArgs(), - otelsql.TraceRowsAffected(), - otelsql.TraceRowsClose(), - otelsql.TraceRowsNext(), - ) - - driverName := fmt.Sprintf("otel-%s", name) - sql.Register(driverName, otelSQLDriverWithCheckNamedValueDisabled{ - wrappedDriver, - }) - cfg := sqlDrivers[flavor] - cfg.driverName = driverName + cfg.driverName = name sqlDrivers[flavor] = cfg } @@ -72,6 +43,35 @@ func init() { UpdateSQLDriverMapping(PostgreSQL, "pgx") } +func InstrumentalizeSQLDrivers() { + for flavor, config := range sqlDrivers { + // otelsql has a function Register which wrap the underlying driver, but does not mirror driver.NamedValuedChecker interface of the underlying driver + // pgx implements this interface and just return nil + // so, we need to manually wrap the driver to implements this interface and return a nil error + db, err := sql.Open(config.driverName, "") + if err != nil { + panic(err) + } + + dri := db.Driver() + + if err = db.Close(); err != nil { + panic(err) + } + + wrappedDriver := otelsql.Wrap(dri, + otelsql.AllowRoot(), + otelsql.TraceAll(), + ) + + config.driverName = fmt.Sprintf("otel-%s", config.driverName) + sql.Register(config.driverName, otelSQLDriverWithCheckNamedValueDisabled{ + wrappedDriver, + }) + sqlDrivers[flavor] = config + } +} + // defaultExecutorProvider use the context to register and manage a sql transaction (if the context is mark as transactional) func defaultExecutorProvider(schema Schema) func(ctx context.Context) (executor, error) { return func(ctx context.Context) (executor, error) { diff --git a/pkg/storage/sqlstorage/logs.go b/pkg/storage/sqlstorage/logs.go index 9b7a67c5b..bbde79ccd 100644 --- a/pkg/storage/sqlstorage/logs.go +++ b/pkg/storage/sqlstorage/logs.go @@ -9,7 +9,6 @@ import ( "time" "github.com/formancehq/go-libs/api" - "github.com/formancehq/go-libs/logging" "github.com/huandu/go-sqlbuilder" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" @@ -68,7 +67,6 @@ func (s *Store) appendLog(ctx context.Context, log ...core.Log) error { return err } - logging.GetLogger(ctx).Debugf("ExecContext: %s %s", query, args) _, err = executor.ExecContext(ctx, query, args...) if err != nil { return s.error(err) diff --git a/pkg/storage/sqlstorage/transactions.go b/pkg/storage/sqlstorage/transactions.go index 81362bf76..ae56e0d53 100644 --- a/pkg/storage/sqlstorage/transactions.go +++ b/pkg/storage/sqlstorage/transactions.go @@ -11,7 +11,6 @@ import ( "time" "github.com/formancehq/go-libs/api" - "github.com/formancehq/go-libs/logging" "github.com/huandu/go-sqlbuilder" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" @@ -473,14 +472,12 @@ func (s *Store) insertTransactions(ctx context.Context, txs ...core.ExpandedTran postingTxIds, postingIndices, sources, destinations, } - logging.GetLogger(ctx).Debugf("ExecContext: %s %s", queryPostings, argsPostings) _, err = executor.ExecContext(ctx, queryPostings, argsPostings...) if err != nil { return s.error(err) } } - logging.GetLogger(ctx).Debugf("ExecContext: %s %s", queryTxs, argsTxs) _, err = executor.ExecContext(ctx, queryTxs, argsTxs...) if err != nil { return s.error(err)