Skip to content

Commit

Permalink
Merge pull request #50 from tensorplex-labs/perf/remove-expired-task-with-no-task-result
Browse files Browse the repository at this point in the history

perf: refactor updating the expired status with tx, and remove task t…
  • Loading branch information
codebender37 authored Nov 7, 2024
2 parents 168d4f6 + 5174a5e commit 011a82f
Showing 1 changed file with 69 additions and 35 deletions.
104 changes: 69 additions & 35 deletions pkg/orm/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,50 +197,84 @@ func (o *TaskORM) countTasksByWorkerSubscription(ctx context.Context, taskTypes
return totalTasks, nil
}

// check every three mins for expired tasks
// check every 10 mins for expired tasks
func (o *TaskORM) UpdateExpiredTasks(ctx context.Context) {
for range time.Tick(3 * time.Minute) {
log.Info().Msg("Checking for expired tasks")
o.clientWrapper.BeforeQuery()
// Fetch all expired tasks
tasks, err := o.dbClient.Task.
FindMany(
db.Task.ExpireAt.Lte(time.Now()),
db.Task.Status.Equals(db.TaskStatusInProgress),
).
OrderBy(db.Task.CreatedAt.Order(db.SortOrderDesc)).
Exec(ctx)
if err != nil {
log.Error().Err(err).Msg("Error in fetching expired tasks")
}
defer o.clientWrapper.AfterQuery()

currentTime := time.Now()
batchSize := 100 // Adjust batch size based on database performance

// Step 1: Delete expired tasks without TaskResults in batches
batchNumber := 0
startTime := time.Now() // Start timing for delete operation
for {
batchNumber++
deleteQuery := `
DELETE FROM "Task"
WHERE "id" IN (
SELECT "id" FROM "Task"
WHERE "expire_at" <= $1
AND "status" IN ($2::"TaskStatus", $3::"TaskStatus")
AND "id" NOT IN (SELECT DISTINCT "task_id" FROM "TaskResult")
LIMIT $4
)
`

// has to include TaskStatusInProgress, to handle Task with in-progress with no results
params := []interface{}{currentTime, db.TaskStatusInProgress, db.TaskStatusExpired, batchSize}

execResult, err := o.dbClient.Prisma.ExecuteRaw(deleteQuery, params...).Exec(ctx)
if err != nil {
log.Error().Err(err).Msg("Error deleting tasks without TaskResults")
break
}

if len(tasks) == 0 {
log.Info().Msg("No newly expired tasks to update skipping...")
continue
} else {
log.Info().Msgf("Fetched %v newly expired tasks", len(tasks))
if execResult.Count == 0 {
log.Info().Msg("No more expired tasks to delete without TaskResults")
break
}

log.Info().Msgf("Deleted %v expired tasks without associated TaskResults in batch %d", execResult.Count, batchNumber)
}
deleteDuration := time.Since(startTime) // Calculate total duration for delete operation
log.Info().Msgf("Total time taken to delete expired tasks without TaskResults: %s", deleteDuration)

// Step 2: Update expired tasks with TaskResults to 'expired' status in batches
batchNumber = 0
startTime = time.Now() // Start timing for update operation
for {
batchNumber++
updateQuery := `
UPDATE "Task"
SET "status" = $1::"TaskStatus", "updated_at" = $2
WHERE "id" IN (
SELECT "id" FROM "Task"
WHERE "expire_at" <= $2
AND "status" = $3::"TaskStatus"
AND "id" IN (SELECT DISTINCT "task_id" FROM "TaskResult")
LIMIT $4
)
`
params := []interface{}{db.TaskStatusExpired, currentTime, db.TaskStatusInProgress, batchSize}

execResult, err := o.dbClient.Prisma.ExecuteRaw(updateQuery, params...).Exec(ctx)
if err != nil {
log.Error().Err(err).Msg("Error updating tasks to expired status")
break
}

var txns []db.PrismaTransaction
for i, taskModel := range tasks {
transaction := o.dbClient.Task.FindUnique(
db.Task.ID.Equals(taskModel.ID),
).Update(
db.Task.Status.Set(db.TaskStatusExpired),
db.Task.UpdatedAt.Set(time.Now()),
).Tx()

txns = append(txns, transaction)

if len(txns) == 100 || (i == len(tasks)-1 && len(txns) > 0) {
if err := o.dbClient.Prisma.Transaction(txns...).Exec(ctx); err != nil {
log.Error().Err(err).Msg("Error in updating batch of task status to expired")
}
txns = []db.PrismaTransaction{}
if execResult.Count == 0 {
log.Info().Msg("No more expired tasks with TaskResults to update")
break
}
}

o.clientWrapper.AfterQuery()
log.Info().Msgf("Updated %v expired tasks with associated TaskResults in batch %d", execResult.Count, batchNumber)
}
updateDuration := time.Since(startTime) // Calculate total duration for update operation
log.Info().Msgf("Total time taken to update expired tasks with TaskResults: %s", updateDuration)
}
}

Expand Down

0 comments on commit 011a82f

Please sign in to comment.