diff --git a/go.mod b/go.mod index f7555a5..fb0831f 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,8 @@ require ( github.com/mimoo/StrobeGo v0.0.0-20181016162300-f8f6d4d2b643 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/relvacode/iso8601 v1.4.0 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect golang.org/x/sync v0.7.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect ) diff --git a/go.sum b/go.sum index 9ca9303..ba12e34 100644 --- a/go.sum +++ b/go.sum @@ -261,6 +261,10 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ulule/limiter/v3 v3.11.2 h1:P4yOrxoEMJbOTfRJR2OzjL90oflzYPPmWg+dvwN2tHA= github.com/ulule/limiter/v3 v3.11.2/go.mod h1:QG5GnFOCV+k7lrL5Y8kgEeeflPH3+Cviqlqa8SVSQxI= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= diff --git a/pkg/api/utils.go b/pkg/api/utils.go index 11ba6ae..8b30d64 100644 --- a/pkg/api/utils.go +++ b/pkg/api/utils.go @@ -117,11 +117,11 @@ func getCallerIP(c *gin.Context) string { // TODO - Need to check if this is the correct way without getting spoofing if runtimeEnv := utils.LoadDotEnv("RUNTIME_ENV"); runtimeEnv == "aws" { callerIp := c.Request.Header.Get("X-Original-Forwarded-For") - log.Info().Msgf("Got caller IP from X-Original-Forwarded-For header: %s", callerIp) + log.Trace().Msgf("Got caller IP from X-Original-Forwarded-For header: %s", callerIp) return callerIp } callerIp := c.ClientIP() - log.Info().Msgf("Got caller IP from ClientIP: %s", callerIp) + log.Trace().Msgf("Got caller IP from ClientIP: %s", callerIp) return callerIp } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 93c88cd..928e3a1 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -11,6 +11,7 @@ import ( "dojo-api/utils" "github.com/redis/go-redis/v9" + "github.com/vmihailenco/msgpack/v5" "github.com/rs/zerolog/log" ) @@ -31,6 +32,56 @@ var ( mu sync.Mutex ) +// CacheKey type for type-safe cache keys +type CacheKey string + +const ( + // Task cache keys + TaskById CacheKey = "task" // Single task by ID + TasksByWorker CacheKey = "task:worker" // List of tasks by worker + + // Task Result cache keys + TaskResultByTaskAndWorker CacheKey = "tr:task:worker" // Task result by task ID and worker ID + TaskResultByWorker CacheKey = "tr:worker" // Task results by worker ID + + // Worker cache keys + WorkerByWallet CacheKey = "worker:wallet" // Worker by wallet address + WorkerCount CacheKey = "worker:count" // Total worker count + + // Subscription cache keys + SubByHotkey CacheKey = "sub:hotkey" // Subscription by hotkey + SubByKey CacheKey = "sub:key" // Subscription by key +) + +// CacheConfig defines cache keys and their expiration times +var CacheConfig = map[CacheKey]time.Duration{ + TaskById: 5 * time.Minute, + TasksByWorker: 2 * time.Minute, + TaskResultByTaskAndWorker: 10 * time.Minute, + TaskResultByWorker: 10 * time.Minute, + WorkerByWallet: 5 * time.Minute, + WorkerCount: 1 * time.Minute, + SubByHotkey: 5 * time.Minute, + SubByKey: 5 * time.Minute, +} + +// GetCacheExpiration returns the expiration time for a given cache key +func GetCacheExpiration(key CacheKey) time.Duration { + if duration, exists := CacheConfig[key]; exists { + return duration + } + return 5 * time.Minute // default expiration +} + +// BuildCacheKey builds a cache key with the given prefix and components +func BuildCacheKey(prefix CacheKey, components ...string) string { + key := string(prefix) + for _, component := range components { + key += ":" + component + } + return key +} + func GetCacheInstance() *Cache { once.Do(func() { mu.Lock() @@ -109,3 +160,30 @@ func (c *Cache) Shutdown() { c.Redis.Close() log.Info().Msg("Successfully closed Redis connection") } + +// GetCacheValue retrieves and unmarshals data from cache using MessagePack +func (c *Cache) GetCacheValue(key string, value interface{}) error { + cachedData, err := c.Get(key) + if err != nil || cachedData == "" { + return fmt.Errorf("cache miss for key: %s", key) + } + + log.Info().Msgf("Cache hit for key: %s", key) + return msgpack.Unmarshal([]byte(cachedData), value) +} + +// SetCacheValue marshals and stores data in cache using MessagePack +func (c *Cache) SetCacheValue(key string, value interface{}) error { + dataBytes, err := msgpack.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + + expiration := GetCacheExpiration(CacheKey(key)) + if err := c.SetWithExpire(key, dataBytes, expiration); err != nil { + return fmt.Errorf("failed to set cache: %w", err) + } + + log.Info().Msgf("Successfully set cache for key: %s", key) + return nil +} diff --git a/pkg/orm/dojo_worker.go b/pkg/orm/dojo_worker.go index 2de7adb..6b5215f 100644 --- a/pkg/orm/dojo_worker.go +++ b/pkg/orm/dojo_worker.go @@ -2,11 +2,14 @@ package orm import ( "context" - "dojo-api/db" "errors" "fmt" "strconv" + "dojo-api/db" + + "dojo-api/pkg/cache" + "github.com/rs/zerolog/log" ) @@ -33,6 +36,17 @@ func (s *DojoWorkerORM) CreateDojoWorker(walletAddress string, chainId string) ( } func (s *DojoWorkerORM) GetDojoWorkerByWalletAddress(walletAddress string) (*db.DojoWorkerModel, error) { + cacheKey := cache.BuildCacheKey(cache.WorkerByWallet, walletAddress) + + var worker *db.DojoWorkerModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &worker); err == nil { + return worker, nil + } + + // Cache miss, fetch from database s.clientWrapper.BeforeQuery() defer s.clientWrapper.AfterQuery() @@ -47,10 +61,26 @@ func (s *DojoWorkerORM) GetDojoWorkerByWalletAddress(walletAddress string) (*db. } return nil, err } + + // Store in cache + if err := cache.SetCacheValue(cacheKey, worker); err != nil { + log.Warn().Err(err).Msg("Failed to set worker cache") + } + return worker, nil } func (s *DojoWorkerORM) GetDojoWorkers() (int, error) { + cacheKey := cache.BuildCacheKey(cache.WorkerCount, "") + var count int + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &count); err == nil { + return count, nil + } + + // Cache miss, fetch from database s.clientWrapper.BeforeQuery() defer s.clientWrapper.AfterQuery() @@ -70,10 +100,15 @@ func (s *DojoWorkerORM) GetDojoWorkers() (int, error) { } workerCountStr := string(result[0].Count) - workerCountInt, err := strconv.Atoi(workerCountStr) + count, err = strconv.Atoi(workerCountStr) if err != nil { return 0, err } - return workerCountInt, nil + // Store in cache + if err := cache.SetCacheValue(cacheKey, count); err != nil { + log.Warn().Err(err).Msg("Failed to set worker count cache") + } + + return count, nil } diff --git a/pkg/orm/subscriptionKey.go b/pkg/orm/subscriptionKey.go index 449181b..f646285 100644 --- a/pkg/orm/subscriptionKey.go +++ b/pkg/orm/subscriptionKey.go @@ -2,9 +2,11 @@ package orm import ( "context" - "dojo-api/db" "errors" + "dojo-api/db" + "dojo-api/pkg/cache" + "github.com/rs/zerolog/log" ) @@ -19,8 +21,15 @@ func NewSubscriptionKeyORM() *SubscriptionKeyORM { } func (a *SubscriptionKeyORM) GetSubscriptionKeysByMinerHotkey(hotkey string) ([]db.SubscriptionKeyModel, error) { - a.clientWrapper.BeforeQuery() - defer a.clientWrapper.AfterQuery() + cacheKey := cache.BuildCacheKey(cache.SubByHotkey, hotkey) + + var subKeys []db.SubscriptionKeyModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &subKeys); err == nil { + return subKeys, nil + } ctx := context.Background() @@ -41,6 +50,11 @@ func (a *SubscriptionKeyORM) GetSubscriptionKeysByMinerHotkey(hotkey string) ([] return nil, err } + // Cache the result + if err := cache.SetCacheValue(cacheKey, apiKeys); err != nil { + log.Error().Err(err).Msgf("Error caching subscription keys") + } + return apiKeys, nil } @@ -56,7 +70,7 @@ func (a *SubscriptionKeyORM) CreateSubscriptionKeyByHotkey(hotkey string, subscr return nil, err } - createdApiKey, err := a.dbClient.SubscriptionKey.CreateOne( + createdSubKey, err := a.dbClient.SubscriptionKey.CreateOne( db.SubscriptionKey.Key.Set(subscriptionKey), db.SubscriptionKey.MinerUser.Link( db.MinerUser.ID.Equals(minerUser.ID), @@ -67,7 +81,7 @@ func (a *SubscriptionKeyORM) CreateSubscriptionKeyByHotkey(hotkey string, subscr log.Error().Err(err).Msgf("Error creating subscription key") return nil, err } - return createdApiKey, nil + return createdSubKey, nil } func (a *SubscriptionKeyORM) DisableSubscriptionKeyByHotkey(hotkey string, subscriptionKey string) (*db.SubscriptionKeyModel, error) { @@ -88,6 +102,15 @@ func (a *SubscriptionKeyORM) DisableSubscriptionKeyByHotkey(hotkey string, subsc } func (a *SubscriptionKeyORM) GetSubscriptionByKey(subScriptionKey string) (*db.SubscriptionKeyModel, error) { + cacheKey := cache.BuildCacheKey(cache.SubByKey, subScriptionKey) + + var foundSubscriptionKey *db.SubscriptionKeyModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &foundSubscriptionKey); err == nil { + return foundSubscriptionKey, nil + } a.clientWrapper.BeforeQuery() defer a.clientWrapper.AfterQuery() @@ -107,5 +130,10 @@ func (a *SubscriptionKeyORM) GetSubscriptionByKey(subScriptionKey string) (*db.S return nil, err } + // Cache the result + if err := cache.SetCacheValue(cacheKey, foundSubscriptionKey); err != nil { + log.Error().Err(err).Msgf("Error caching subscription key") + } + return foundSubscriptionKey, nil } diff --git a/pkg/orm/task.go b/pkg/orm/task.go index c8e3048..b702e5d 100644 --- a/pkg/orm/task.go +++ b/pkg/orm/task.go @@ -9,6 +9,7 @@ import ( "time" "dojo-api/db" + "dojo-api/pkg/cache" sq "github.com/Masterminds/squirrel" @@ -49,45 +50,88 @@ func (o *TaskORM) CreateTask(ctx context.Context, task db.InnerTask, minerUserId return createdTask, err } +// GetById with caching func (o *TaskORM) GetById(ctx context.Context, taskId string) (*db.TaskModel, error) { + cacheKey := cache.BuildCacheKey(cache.TaskById, taskId) + + var task *db.TaskModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &task); err == nil { + return task, nil + } + + // Cache miss, fetch from database o.clientWrapper.BeforeQuery() defer o.clientWrapper.AfterQuery() + task, err := o.dbClient.Task.FindUnique( db.Task.ID.Equals(taskId), ).Exec(ctx) - return task, err + if err != nil { + return nil, err + } + + // Store in cache + if err := cache.SetCacheValue(cacheKey, task); err != nil { + log.Warn().Err(err).Msg("Failed to set cache") + } + + return task, nil } -// TODO: Optimization +// Modified GetTasksByWorkerSubscription with caching func (o *TaskORM) GetTasksByWorkerSubscription(ctx context.Context, workerId string, offset, limit int, sortQuery db.TaskOrderByParam, taskTypes []db.TaskType) ([]db.TaskModel, int, error) { + // Convert TaskTypes to strings + typeStrs := make([]string, len(taskTypes)) + for i, t := range taskTypes { + typeStrs[i] = string(t) + } + cacheKey := cache.BuildCacheKey(cache.TasksByWorker, workerId, strconv.Itoa(offset), strconv.Itoa(limit), strings.Join(typeStrs, ",")) + + var tasks []db.TaskModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &tasks); err == nil { + totalTasks, err := o.countTasksByWorkerSubscription(ctx, taskTypes, nil) + if err != nil { + log.Error().Err(err).Msgf("Error fetching total tasks for worker ID %v", workerId) + return tasks, 0, err + } + return tasks, totalTasks, nil + } + + // Cache miss, proceed with database query o.clientWrapper.BeforeQuery() defer o.clientWrapper.AfterQuery() - // Fetch all active WorkerPartner records to retrieve MinerUser's subscription keys. + + // Rest of the existing implementation... partners, err := o.dbClient.WorkerPartner.FindMany( db.WorkerPartner.WorkerID.Equals(workerId), db.WorkerPartner.IsDeleteByMiner.Equals(false), db.WorkerPartner.IsDeleteByWorker.Equals(false), ).Exec(ctx) if err != nil { - log.Error().Err(err).Msg("Error in fetching WorkerPartner by WorkerID") + log.Error().Err(err).Msgf("Error fetching WorkerPartner by WorkerID for worker ID %v", workerId) return nil, 0, err } - // Collect Subscription keys from the fetched WorkerPartner records var subscriptionKeys []string for _, partner := range partners { subscriptionKeys = append(subscriptionKeys, partner.MinerSubscriptionKey) } if len(subscriptionKeys) == 0 { - log.Error().Err(err).Msg("No WorkerPartner found with the given WorkerID") + log.Error().Msgf("No subscription keys found for worker ID %v", workerId) return nil, 0, err } filterParams := []db.TaskWhereParam{ db.Task.MinerUser.Where( db.MinerUser.SubscriptionKeys.Some( - db.SubscriptionKey.Key.In(subscriptionKeys), // SubscriptionKey should be one of the keys in the subscriptionKeys slice. + db.SubscriptionKey.Key.In(subscriptionKeys), ), ), } @@ -96,27 +140,30 @@ func (o *TaskORM) GetTasksByWorkerSubscription(ctx context.Context, workerId str filterParams = append(filterParams, db.Task.Type.In(taskTypes)) } - log.Debug().Interface("taskTypes", taskTypes).Msgf("Filter Params: %v", filterParams) - - // Fetch tasks associated with these subscription keys - tasks, err := o.dbClient.Task.FindMany( + tasks, err = o.dbClient.Task.FindMany( filterParams..., ).OrderBy(sortQuery). Skip(offset). Take(limit). Exec(ctx) if err != nil { - log.Error().Err(err).Msg("Error in fetching tasks by WorkerSubscriptionKey") + log.Error().Err(err).Msgf("Error fetching tasks for worker ID %v", workerId) return nil, 0, err } totalTasks, err := o.countTasksByWorkerSubscription(ctx, taskTypes, subscriptionKeys) if err != nil { - log.Error().Err(err).Msg("Error in fetching total tasks by WorkerSubscriptionKey") + log.Error().Err(err).Msgf("Error fetching total tasks for worker ID %v", workerId) return nil, 0, err } log.Info().Int("totalTasks", totalTasks).Msgf("Successfully fetched total tasks fetched for worker ID %v", workerId) + + // Store in cache + if err := cache.SetCacheValue(cacheKey, tasks); err != nil { + log.Warn().Err(err).Msg("Failed to set cache") + } + return tasks, totalTasks, nil } @@ -273,6 +320,7 @@ func (o *TaskORM) UpdateExpiredTasks(ctx context.Context) { } } +// Modify GetCompletedTaskCount to use the new pattern func (o *TaskORM) GetCompletedTaskCount(ctx context.Context) (int, error) { o.clientWrapper.BeforeQuery() defer o.clientWrapper.AfterQuery() @@ -292,12 +340,12 @@ func (o *TaskORM) GetCompletedTaskCount(ctx context.Context) (int, error) { } taskCountStr := string(result[0].Count) - taskCountInt, err := strconv.Atoi(taskCountStr) + count, err := strconv.Atoi(taskCountStr) if err != nil { return 0, err } - return taskCountInt, nil + return count, nil } func (o *TaskORM) GetNextInProgressTask(ctx context.Context, taskId string, workerId string) (*db.TaskModel, error) { diff --git a/pkg/orm/task_result.go b/pkg/orm/task_result.go index a792dcf..9b7d9d3 100644 --- a/pkg/orm/task_result.go +++ b/pkg/orm/task_result.go @@ -2,10 +2,14 @@ package orm import ( "context" - "dojo-api/db" "fmt" "strconv" "time" + + "dojo-api/db" + "dojo-api/pkg/cache" + + "github.com/rs/zerolog/log" ) type TaskResultORM struct { @@ -36,22 +40,71 @@ func (t *TaskResultORM) CreateTaskResult(ctx context.Context, taskResult *db.Inn func (t *TaskResultORM) GetTaskResultsByTaskId(ctx context.Context, taskId string) ([]db.TaskResultModel, error) { t.clientWrapper.BeforeQuery() defer t.clientWrapper.AfterQuery() + return t.client.TaskResult.FindMany(db.TaskResult.TaskID.Equals(taskId)).Exec(ctx) } -func (orm *TaskResultORM) GetCompletedTResultByTaskAndWorker(ctx context.Context, taskId string, workerId string) ([]db.TaskResultModel, error) { - return orm.client.TaskResult.FindMany( +func (t *TaskResultORM) GetCompletedTResultByTaskAndWorker(ctx context.Context, taskId string, workerId string) ([]db.TaskResultModel, error) { + cacheKey := cache.BuildCacheKey(cache.TaskResultByTaskAndWorker, taskId, workerId) + + var results []db.TaskResultModel + cacheInstance := cache.GetCacheInstance() + + // Try to get from cache + if err := cacheInstance.GetCacheValue(cacheKey, &results); err == nil { + return results, nil + } + + // Cache miss, fetch from database + t.clientWrapper.BeforeQuery() + defer t.clientWrapper.AfterQuery() + + results, err := t.client.TaskResult.FindMany( db.TaskResult.TaskID.Equals(taskId), db.TaskResult.WorkerID.Equals(workerId), db.TaskResult.Status.Equals(db.TaskResultStatusCompleted), ).Exec(ctx) + if err != nil { + return nil, err + } + + // Set cache + if err := cacheInstance.SetCacheValue(cacheKey, results); err != nil { + log.Warn().Err(err).Msg("Failed to set cache") + } + + return results, nil } -func (orm *TaskResultORM) GetCompletedTResultByWorker(ctx context.Context, workerId string) ([]db.TaskResultModel, error) { - return orm.client.TaskResult.FindMany( +func (t *TaskResultORM) GetCompletedTResultByWorker(ctx context.Context, workerId string) ([]db.TaskResultModel, error) { + cacheKey := cache.BuildCacheKey(cache.TaskResultByWorker, workerId) + + var results []db.TaskResultModel + cache := cache.GetCacheInstance() + + // Try to get from cache first + if err := cache.GetCacheValue(cacheKey, &results); err == nil { + return results, nil + } + + // Cache miss, fetch from database + t.clientWrapper.BeforeQuery() + defer t.clientWrapper.AfterQuery() + + results, err := t.client.TaskResult.FindMany( db.TaskResult.WorkerID.Equals(workerId), db.TaskResult.Status.Equals(db.TaskResultStatusCompleted), ).Exec(ctx) + if err != nil { + return nil, err + } + + // Store in cache + if err := cache.SetCacheValue(cacheKey, results); err != nil { + log.Warn().Err(err).Msg("Failed to set task result cache") + } + + return results, nil } func (t *TaskResultORM) CreateTaskResultWithInvalid(ctx context.Context, taskResult *db.InnerTaskResult) (*db.TaskResultModel, error) {