From 6d6ed27ec615eeaec069f15894a518c4eb5763c6 Mon Sep 17 00:00:00 2001 From: Geoffrey Mureithi <95377562+geofmureithi@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:01:08 +0300 Subject: [PATCH] [bump] introduce new version: 0.6 (#459) * fix: improve external api for redis * fix: improve exports for redis * fix: expose redis codec * Feature: v0.6.0-alpha.0 version of apalis Breaking Changes: - Dropped traits Job and Message, please use namespace * fix: minor fixes on some failures * lint: cargo fmt * fix: remove Job impl * lint: cargo fmt * bench: improve polling * fix: introduce namespace and codec config (#339) * fix: introduce namespace and codec config * fix: missing apis * Version: 0.6.0-alpha.1 (#342) * api: for redis and sqlite * Version: 0.6.0-alpha.1 Changelog: - Redis storage doesn't require pool to be `Clone`. Allows use of deadpool-redis among others. - Namespace is picked by default for `new` methods. * fix: docs and tests * lint: cargo clippy and fmt * postgres: add a listener example * bump: to v0.6.0-alpha.1 (#343) * api: for redis and sqlite * Version: 0.6.0-alpha.1 Changelog: - Redis storage doesn't require pool to be `Clone`. Allows use of deadpool-redis among others. - Namespace is picked by default for `new` methods. 
* fix: docs and tests * lint: cargo clippy and fmt * postgres: add a listener example * bump: to v0.6.0-alpha.1 * fix: allow cd for prereleases (#349) * Remove `Clone` constraints and buffer the service (#348) * feat: remove the `Clone` requirements for services * test save * fix: get buffered layer working * update: remove clone & update api * fix: tests and api * lint: clippy fixes * lint: cargo fmt * bump: to 0.6.0-rc.1 (#350) * feat: add rsmq example (#353) * Fix: load layer from poller (#354) * fix: backend layers were not loaded * fix: handle clone * Fix: mq example (#355) * fix: mq ack * lint: fmt * fix: handle unwraps in storages (#356) * fix: handle unwraps in storages * fix: ensure no unwrap * fix: better apalis deps allowing tree shaking for backends (#357) * fix: better apalis deps allowing tree shaking for backends * fix: remove backend features in the root crate * standardize backend for storage and mq (#358) * fix: standardize backend for storage and mq * fix: minor fixes * feat: standardize cron as backend (#359) * fix: remove non-working restapi example (#360) * fix: expose the missing apis (#361) * bump: to new version (#362) * Make Config accessible publicly (#364) * fix: add missing exposed config * fix: add getters * fix: die if retries is zero (#365) * Feature: Add a layer that catches panics (#366) * Feature: Add a layer that catches panics This allows preventing job execution from killing workers and returns an error containing the backtrace * fix: backtrace as it may be different * add: example for catch-panic * fix: make not default * Feature: Save results for storages (#369) * Feature: Save results for storages Currently just the status is stored, this PR adds the ability to save the result * fix: result from storage * fix: kill and abort issue * Bump: to 0.6.0-rc.3 (#370) * fix: serde for sql request (#371) * fix: serde for sql request * fix: serde for attempts * lint: fmt * fix: handle attempts in storages (#373) * fix: handle attempts 
in storages * fix: chrono serialization * fix: tests failing because of tests * add: test utils that allow backend polling during tests (#374) * add: test utils that allow backend polling during tests * fix: introduce testwrapper and add more tests * fix: add sample for testing * fix: more fixes and actions fixes * fix: more fixes on vacuuming * tests: improve cleanup and generic testing * fix: improve testing and fix some found bugs * fix: postgres query and remove incompatible tests * fix: remove redis incompatible check * fix: minor fixes * fix: postgres json elements * bump: to 0.6.0-rc.4 (#377) * fix: handle 0 retries (#378) * fix: ack api to allow backend to handle differently (#383) * fix: ack api to allow backend to handle differently * fix: related to storage tests * fix: calculate status for postgres * fix(deps): update rust crate sqlx to 0.8.0 (#380) * chore: fix typos (#346) * chore: Add repository to metadata (#345) * fix(deps): update rust crate sqlx to 0.8.0 * fix: sqlite example --------- Co-authored-by: John Vandenberg Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: geofmureithi Co-authored-by: Geoffrey Mureithi <95377562+geofmureithi@users.noreply.github.com> * bump: to v0.6.0-rc.5 (#385) * chore: standardize codec usage (#388) * bump: to v0.6.0-rc.5 * fix: standardize codec usage * lint: cargo fmt * Chore/more examples (#389) * add: catch-panic example * add: graceful shutdown example * add: unmonitored example * add: arguments example * fix: minor updates * fix: sql tests * fix: minor updates * fix: improve on benches (#379) * fix: improve on benches * fix: bench trigger * fix: include tokio for sqlx * fix: improve the benching approach * fix: mysql api * fix: redis api * fix: improve bench approach, remove counter * remove: setup * remove: pg * fix: pg * fix: pg * fix(deps): update rust crate sqlx to 0.8.1 [security] (#400) Co-authored-by: renovate[bot] 
<29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Geoffrey Mureithi <95377562+geofmureithi@users.noreply.github.com> * fix: add some missing data required for dependency injection (#409) * fix: add some missing data required for dependency injection * lint: clippy and fmt * remove: benchmarks (#410) They will be moved to https://github.com/geofmureithi/apalis-benchmarks * bump: to 0.6.0-rc.6 (#412) * Update async-std to 1.13 (#413) * Feature: Introducing Request Context (#416) * wip: introduce context to request * fix: get request context working * lint: cargo fmt * fix: get tests compiling * add: push_request and shedule_request * fix: task_id for Testwrapper * fix: minor checks and fixes on postgres tests * fix: bug on postgres fetch_next * bump: to 0.6.0-rc.7 (#418) * fix: apply `FromRequest` for items in `Parts` (#425) Problem: We are missing crucial `FromRequest` impls for: - TaskId - Attempt - Namespace Also removed `Context` Solution: Implement `FromRequest` for these Types. * fix:[bug] include backend provided layer in service layers. (#426) * fix:[bug] include backend provided layer in service layers. Problem: The current worker logic is missing an implementation where the backend provided layer should be added to the service's layer. This is a critical issue that affects all v0.6.0-rc-7 users and they should update as soon as a new release is done. Solution: - Add backend layers to service's layer. - Add worker_consume tests on the storages to prevent regression on this. 
* chore: comment an enforcement rule not yet followed by redis * chore: bump to 0.6.0-rc.8 (#430) * fix: apply max_attempts set via SqlContext (#447) So that a custom number of attempts can be configured: let mut ctx = SqlContext::new(); ctx.set_max_attempts(2); let req = Request::new_with_ctx(job, ctx); storage.push_request(req).await.unwrap(); While the default is still to try up to 25 times: storage.push(job).await.unwrap(); * Bump redis (#442) * feat: re-export sqlx (#451) Making sqlx accessible to users of apalis without requiring them to explicitly add it as a dependency. * feat: Improve Worker management and drop Executor (#428) * feat: introducing WorkerBuilderExt which makes the work of building a new worker way easier. * improve: worker api almost there * fix: radical improvements and updates. Removed executor and got graceful shutdown working * chore: deprecate register with count and force builder order * chore: more improvements on the worker * fix: allow DI for Worker * add: get the task count by a worker * lint: fmt and clippy * fix: allow worker stopping * Chore/better api (#452) * fix: relax the api provided for sqlx req * lint: clippy and fmt * feat: add recovery of abandoned jobs to backend heartbeats (#453) * feat: add recovery of abandoned jobs to backend heartbeats * lint: fmt * fix: attempt to get tests passing * fix: attempt to get tests passing * fix: minor fix typo * fix: minor different solutions * fix: better handle attempts * handle postgres edge case * fix: better handling * feat: allow backends to emit errors (#454) * feat: allow backends to emit errors * lint: fmt * fix: pass in a reference to prevent mutation * Feat: Introduce simple ability to pipe cron jobs to any backend (#455) * Feat: Introduce simple ability to pipe cron jobs to any backend This feature allows you to quickly persist cron jobs guaranteeing they will be run and can be distributed * lint: cargo fmt * chore/dev-branch * feat: return of exposing backends to help in 
building apis (#457) (#458) * fmt: Cargo.toml * fix: removed features * fix: run only specific tests --------- Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: John Vandenberg Co-authored-by: Mathias Lafeldt Co-authored-by: zakstucke <44890343+zakstucke@users.noreply.github.com> --- .github/workflows/bench.yaml | 39 - .github/workflows/cd.yaml | 5 +- .github/workflows/ci.yaml | 4 +- .github/workflows/mysql.yaml | 7 - .github/workflows/postgres.yaml | 9 +- .github/workflows/redis.yaml | 8 +- .github/workflows/sqlite.yaml | 9 +- Cargo.toml | 94 +- README.md | 31 +- benches/storages.rs | 158 --- examples/actix-web/Cargo.toml | 3 +- examples/actix-web/src/main.rs | 17 +- examples/async-std-runtime/Cargo.toml | 5 +- examples/async-std-runtime/src/main.rs | 43 +- examples/axum/Cargo.toml | 5 +- examples/axum/src/main.rs | 19 +- examples/basics/Cargo.toml | 5 +- examples/basics/src/layer.rs | 15 +- examples/basics/src/main.rs | 61 +- examples/catch-panic/Cargo.toml | 24 + examples/catch-panic/src/main.rs | 54 + examples/cron/Cargo.toml | 23 + examples/cron/src/main.rs | 41 + examples/email-service/Cargo.toml | 3 +- examples/email-service/src/lib.rs | 44 +- examples/fn-args/Cargo.toml | 22 + examples/fn-args/src/main.rs | 73 ++ examples/graceful-shutdown/Cargo.toml | 23 + examples/graceful-shutdown/src/main.rs | 52 + examples/mysql/Cargo.toml | 7 +- examples/mysql/src/main.rs | 13 +- examples/persisted-cron/Cargo.toml | 28 + examples/persisted-cron/src/main.rs | 57 + examples/postgres/Cargo.toml | 3 +- examples/postgres/src/main.rs | 32 +- examples/prometheus/Cargo.toml | 3 +- examples/prometheus/src/main.rs | 19 +- examples/redis-deadpool/Cargo.toml | 22 + examples/redis-deadpool/src/main.rs | 57 + examples/redis-mq-example/Cargo.toml | 25 + examples/redis-mq-example/src/main.rs | 192 ++++ examples/redis-with-msg-pack/Cargo.toml | 22 + examples/redis-with-msg-pack/src/main.rs | 73 ++ examples/redis/Cargo.toml | 3 +- 
examples/redis/src/main.rs | 40 +- examples/rest-api/Cargo.toml | 14 +- examples/rest-api/README.md | 14 +- examples/rest-api/src/main.rs | 363 +----- examples/sentry/Cargo.toml | 3 +- examples/sentry/src/main.rs | 19 +- examples/sqlite/Cargo.toml | 5 +- examples/sqlite/src/job.rs | 5 - examples/sqlite/src/main.rs | 20 +- examples/tracing/Cargo.toml | 5 +- examples/tracing/src/main.rs | 21 +- examples/unmonitored-worker/Cargo.toml | 22 + examples/unmonitored-worker/src/main.rs | 55 + packages/apalis-core/Cargo.toml | 6 +- packages/apalis-core/src/backend.rs | 92 ++ packages/apalis-core/src/builder.rs | 120 +- packages/apalis-core/src/codec/json.rs | 58 +- .../apalis-core/src/codec/message_pack.rs | 0 packages/apalis-core/src/codec/mod.rs | 20 + packages/apalis-core/src/data.rs | 23 + packages/apalis-core/src/error.rs | 121 +- packages/apalis-core/src/executor.rs | 7 - packages/apalis-core/src/layers.rs | 157 ++- packages/apalis-core/src/lib.rs | 331 +++++- packages/apalis-core/src/memory.rs | 54 +- packages/apalis-core/src/monitor/mod.rs | 296 ++--- packages/apalis-core/src/monitor/shutdown.rs | 9 +- packages/apalis-core/src/mq/mod.rs | 27 +- packages/apalis-core/src/poller/mod.rs | 61 +- packages/apalis-core/src/request.rs | 163 ++- packages/apalis-core/src/response.rs | 130 ++- packages/apalis-core/src/service_fn.rs | 66 +- packages/apalis-core/src/storage/mod.rs | 76 +- packages/apalis-core/src/task/attempt.rs | 51 +- packages/apalis-core/src/task/mod.rs | 2 + packages/apalis-core/src/task/namespace.rs | 52 + packages/apalis-core/src/task/task_id.rs | 10 +- packages/apalis-core/src/worker/mod.rs | 737 ++++++------ packages/apalis-core/src/worker/stream.rs | 56 - packages/apalis-cron/Cargo.toml | 11 +- packages/apalis-cron/README.md | 42 +- packages/apalis-cron/src/lib.rs | 125 +- packages/apalis-cron/src/pipe.rs | 77 ++ packages/apalis-redis/Cargo.toml | 16 +- .../lua/{ack_job.lua => done_job.lua} | 7 +- packages/apalis-redis/lua/kill_job.lua | 16 +- 
packages/apalis-redis/lua/retry_job.lua | 11 +- packages/apalis-redis/src/expose.rs | 258 +++++ packages/apalis-redis/src/lib.rs | 12 +- packages/apalis-redis/src/storage.rs | 1024 ++++++++++------- packages/apalis-sql/Cargo.toml | 17 +- packages/apalis-sql/src/context.rs | 103 +- packages/apalis-sql/src/from_row.rs | 121 +- packages/apalis-sql/src/lib.rs | 235 +++- packages/apalis-sql/src/mysql.rs | 584 ++++++---- packages/apalis-sql/src/postgres.rs | 793 ++++++++----- packages/apalis-sql/src/sqlite.rs | 539 +++++---- src/layers/catch_panic/mod.rs | 207 ++++ src/layers/mod.rs | 291 +++++ src/layers/prometheus/mod.rs | 39 +- src/layers/retry/mod.rs | 34 +- src/layers/sentry/mod.rs | 44 +- src/layers/tracing/make_span.rs | 24 +- src/layers/tracing/mod.rs | 26 +- src/layers/tracing/on_failure.rs | 22 +- src/layers/tracing/on_request.rs | 28 +- src/layers/tracing/on_response.rs | 8 +- src/lib.rs | 99 +- 112 files changed, 6158 insertions(+), 3418 deletions(-) delete mode 100644 .github/workflows/bench.yaml delete mode 100644 benches/storages.rs create mode 100644 examples/catch-panic/Cargo.toml create mode 100644 examples/catch-panic/src/main.rs create mode 100644 examples/cron/Cargo.toml create mode 100644 examples/cron/src/main.rs create mode 100644 examples/fn-args/Cargo.toml create mode 100644 examples/fn-args/src/main.rs create mode 100644 examples/graceful-shutdown/Cargo.toml create mode 100644 examples/graceful-shutdown/src/main.rs create mode 100644 examples/persisted-cron/Cargo.toml create mode 100644 examples/persisted-cron/src/main.rs create mode 100644 examples/redis-deadpool/Cargo.toml create mode 100644 examples/redis-deadpool/src/main.rs create mode 100644 examples/redis-mq-example/Cargo.toml create mode 100644 examples/redis-mq-example/src/main.rs create mode 100644 examples/redis-with-msg-pack/Cargo.toml create mode 100644 examples/redis-with-msg-pack/src/main.rs create mode 100644 examples/unmonitored-worker/Cargo.toml create mode 100644 
examples/unmonitored-worker/src/main.rs create mode 100644 packages/apalis-core/src/backend.rs delete mode 100644 packages/apalis-core/src/codec/message_pack.rs delete mode 100644 packages/apalis-core/src/executor.rs create mode 100644 packages/apalis-core/src/task/namespace.rs delete mode 100644 packages/apalis-core/src/worker/stream.rs create mode 100644 packages/apalis-cron/src/pipe.rs rename packages/apalis-redis/lua/{ack_job.lua => done_job.lua} (70%) create mode 100644 packages/apalis-redis/src/expose.rs create mode 100644 src/layers/catch_panic/mod.rs diff --git a/.github/workflows/bench.yaml b/.github/workflows/bench.yaml deleted file mode 100644 index 9edc5a7..0000000 --- a/.github/workflows/bench.yaml +++ /dev/null @@ -1,39 +0,0 @@ -on: - pull_request: - paths: - - 'packages/**' - - '.github/workflows/bench.yaml' -name: Benchmark -jobs: - storageBenchmark: - name: Storage Benchmarks - runs-on: ubuntu-latest - services: - redis: - image: redis - ports: - - 6379:6379 - postgres: - image: postgres:17 - env: - POSTGRES_PASSWORD: postgres - ports: - - 5432:5432 - mysql: - image: mysql:8 - env: - MYSQL_DATABASE: test - MYSQL_USER: test - MYSQL_PASSWORD: test - MYSQL_ROOT_PASSWORD: root - ports: - - 3306:3306 - env: - POSTGRES_URL: postgres://postgres:postgres@localhost/postgres - MYSQL_URL: mysql://test:test@localhost/test - REDIS_URL: redis://127.0.0.1/ - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - - uses: boa-dev/criterion-compare-action@v3 - with: - branchName: ${{ github.base_ref }} diff --git a/.github/workflows/cd.yaml b/.github/workflows/cd.yaml index 8cdfe15..df339a7 100644 --- a/.github/workflows/cd.yaml +++ b/.github/workflows/cd.yaml @@ -31,10 +31,7 @@ jobs: # vX.Y.Z-foo is pre-release version VERSION=${GITHUB_REF#refs/tags/v} VERSION_NUMBER=${VERSION%-*} - PUBLISH_OPTS="--dry-run" - if [[ $VERSION == $VERSION_NUMBER ]]; then - PUBLISH_OPTS="" - fi + PUBLISH_OPTS="" echo VERSION=${VERSION} >> $GITHUB_ENV echo 
PUBLISH_OPTS=${PUBLISH_OPTS} >> $GITHUB_ENV echo VERSION_NUMBER=${VERSION_NUMBER} >> $GITHUB_ENV diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a400bee..c4a9e52 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,7 +20,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: check - args: --features tokio-comp + args: --all - uses: actions-rs/cargo@v1 with: command: check @@ -39,7 +39,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test - args: --features tokio-comp + # args: --all fmt: diff --git a/.github/workflows/mysql.yaml b/.github/workflows/mysql.yaml index f842fda..50e337c 100644 --- a/.github/workflows/mysql.yaml +++ b/.github/workflows/mysql.yaml @@ -1,11 +1,4 @@ on: - push: - paths: - - "packages/apalis-sql/src/lib.rs" - - "packages/apalis-sql/mysql.rs" - - "packages/apalis-sql/src/migrations/mysql/**" - - "packages/apalis-sql/src/Cargo.toml" - - ".github/workflows/mysql.yaml" pull_request: paths: - "packages/apalis-sql/src/lib.rs" diff --git a/.github/workflows/postgres.yaml b/.github/workflows/postgres.yaml index 63eb68c..93eb8b9 100644 --- a/.github/workflows/postgres.yaml +++ b/.github/workflows/postgres.yaml @@ -1,11 +1,4 @@ on: - push: - paths: - - "packages/apalis-sql/src/lib.rs" - - "packages/apalis-sql/postgres.rs" - - "packages/apalis-sql/src/migrations/postgres/**" - - "packages/apalis-sql/src/Cargo.toml" - - ".github/workflows/postgres.yaml" pull_request: paths: - "packages/apalis-sql/src/lib.rs" @@ -37,4 +30,4 @@ jobs: toolchain: stable override: true - run: cargo test --no-default-features --features postgres,migrate,tokio-comp -- --test-threads=1 - working-directory: packages/apalis-sql \ No newline at end of file + working-directory: packages/apalis-sql diff --git a/.github/workflows/redis.yaml b/.github/workflows/redis.yaml index 395d780..66d3673 100644 --- a/.github/workflows/redis.yaml +++ b/.github/workflows/redis.yaml @@ -1,8 +1,4 @@ on: - push: - paths: - - 
"packages/apalis-redis/**" - - ".github/workflows/redis.yaml" pull_request: paths: - "packages/apalis-redis/**" @@ -26,11 +22,11 @@ jobs: profile: minimal toolchain: stable override: true - - run: cargo test --features tokio-comp -- --test-threads=1 + - run: cargo test -- --test-threads=1 working-directory: packages/apalis-redis env: REDIS_URL: redis://127.0.0.1/ - - run: cargo test --features async-std-comp -- --test-threads=1 + - run: cargo test -- --test-threads=1 working-directory: packages/apalis-redis env: REDIS_URL: redis://127.0.0.1/ diff --git a/.github/workflows/sqlite.yaml b/.github/workflows/sqlite.yaml index 5c30911..f3e9ade 100644 --- a/.github/workflows/sqlite.yaml +++ b/.github/workflows/sqlite.yaml @@ -1,11 +1,4 @@ on: - push: - paths: - - "packages/apalis-sql/src/lib.rs" - - "packages/apalis-sql/src/sqlite.rs" - - "packages/apalis-sql/src/migrations/sqlite/**" - - "packages/apalis-sql/src/Cargo.toml" - - ".github/workflows/sqlite.yaml" pull_request: paths: - "packages/apalis-sql/src/lib.rs" @@ -28,4 +21,4 @@ jobs: toolchain: stable override: true - run: cargo test --no-default-features --features sqlite,migrate,tokio-comp -- --test-threads=1 - working-directory: packages/apalis-sql \ No newline at end of file + working-directory: packages/apalis-sql diff --git a/Cargo.toml b/Cargo.toml index 886e73b..e0d9fcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ repository = "https://github.com/geofmureithi/apalis" [package] name = "apalis" -version = "0.5.5" +version = "0.6.0" authors = ["Geoffrey Mureithi "] description = "Simple, extensible multithreaded background job processing for Rust" edition.workspace = true @@ -19,18 +19,8 @@ categories = ["database"] bench = false [features] -default = ["tracing", "tokio-comp"] - -## Include redis storage -redis = ["apalis-redis"] -## Include Postgres storage -postgres = ["apalis-sql/postgres"] -## Include SQlite storage -sqlite = ["apalis-sql/sqlite"] -## Include MySql storage -mysql = 
["apalis-sql/mysql"] -## Include Cron functionality -cron = ["apalis-cron"] +default = ["tracing"] + ## Support Tracing 👀 tracing = ["dep:tracing", "dep:tracing-futures"] @@ -47,20 +37,8 @@ timeout = ["tower/timeout"] limit = ["tower/limit"] ## Support filtering jobs based on a predicate filter = ["tower/filter"] -## Compatibility with async-std and smol runtimes -async-std-comp = [ - "apalis-sql?/async-std-comp", - "apalis-redis?/async-std-comp", - "apalis-cron?/async-std-comp", - "async-std", -] -## Compatibility with tokio and actix runtimes -tokio-comp = [ - "apalis-sql?/tokio-comp", - "apalis-redis?/tokio-comp", - "apalis-cron?/tokio-comp", - "tokio", -] +## Captures panics in executions and convert them to errors +catch-panic = [] layers = [ "sentry", @@ -70,35 +48,16 @@ layers = [ "timeout", "limit", "filter", + "catch-panic", ] docsrs = ["document-features"] -[dependencies.apalis-redis] -version = "0.5.5" -optional = true -path = "./packages/apalis-redis" -default-features = false - -[dependencies.apalis-sql] - -version = "0.5.5" -features = ["migrate"] -optional = true -default-features = false -path = "./packages/apalis-sql" - [dependencies.apalis-core] -version = "0.5.5" +version = "0.6.0" default-features = false path = "./packages/apalis-core" -[dependencies.apalis-cron] -version = "0.5.5" -optional = true -default-features = false -path = "./packages/apalis-cron" - [dependencies.document-features] version = "0.2" optional = true @@ -116,22 +75,24 @@ pprof = { version = "0.14", features = ["flamegraph"] } paste = "1.0.14" serde = "1" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } -apalis = { path = ".", features = ["redis", "sqlite", "postgres", "mysql"] } -redis = { version = "0.25.3", default-features = false, features = [ - "script", - "aio", - "connection-manager", +apalis = { path = ".", features = ["limit"] } +apalis-redis = { path = "./packages/apalis-redis" } +apalis-sql = { path = "./packages/apalis-sql", features = [ + 
"postgres", + "mysql", + "sqlite", +] } +redis = { version = "0.27", default-features = false, features = [ + "tokio-comp", + "script", + "aio", + "connection-manager", ] } [dev-dependencies.sqlx] version = "0.8.1" default-features = false -features = ["chrono", "mysql", "sqlite", "postgres"] - - -[[bench]] -name = "storages" -harness = false +features = ["chrono", "mysql", "sqlite", "postgres", "runtime-tokio"] [workspace] members = [ @@ -150,17 +111,22 @@ members = [ "examples/axum", "examples/prometheus", "examples/tracing", - # "examples/rest-api", "examples/async-std-runtime", "examples/basics", + "examples/redis-with-msg-pack", + "examples/redis-deadpool", + "examples/redis-mq-example", + "examples/cron", + "examples/catch-panic", + "examples/graceful-shutdown", + "examples/unmonitored-worker", + "examples/fn-args", + "examples/persisted-cron", + "examples/rest-api", ] [dependencies] -tokio = { version = "1", features = [ - "rt", -], default-features = false, optional = true } -async-std = { version = "1", optional = true } tower = { version = "0.5", features = ["util"], default-features = false } tracing-futures = { version = "0.2.5", optional = true, default-features = false } sentry-core = { version = "0.34.0", optional = true, default-features = false } diff --git a/README.md b/README.md index 2582e14..e878aa5 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ - Simple and predictable job handling model. - Jobs handlers with a macro free API. -- Take full advantage of the [tower] ecosystem of +- Take full advantage of the [`tower`] ecosystem of middleware, services, and utilities. - Runtime agnostic - Use tokio, smol etc. - Optional Web interface to help you manage your jobs. 
@@ -59,26 +59,23 @@ To get started, just add to Cargo.toml ```toml [dependencies] -apalis = { version = "0.5", features = ["redis"] } # Backends available: postgres, sqlite, mysql, amqp +apalis = { version = "0.6" } +apalis-redis = { version = "0.6" } +# apalis-sql = { version = "0.6", features = ["postgres"] } # or mysql, sqlite ``` ## Usage ```rust use apalis::prelude::*; -use apalis::redis::RedisStorage; +use apalis_redis::{RedisStorage, Config}; use serde::{Deserialize, Serialize}; -use anyhow::Result; #[derive(Debug, Deserialize, Serialize)] struct Email { to: String, } -impl Job for Email { - const NAME: &'static str = "apalis::Email"; -} - /// A function that will be converted into a service. async fn send_email(job: Email, data: Data) -> Result<(), Error> { /// execute job @@ -86,16 +83,18 @@ async fn send_email(job: Email, data: Data) -> Result<(), Error> { } #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> { std::env::set_var("RUST_LOG", "debug"); env_logger::init(); let redis_url = std::env::var("REDIS_URL").expect("Missing env variable REDIS_URL"); - let storage = RedisStorage::new(redis).await?; + let conn = apalis_redis::connect(redis_url).await.expect("Could not connect"); + let storage = RedisStorage::new(conn); Monitor::new() - .register_with_count(2, { + .register({ WorkerBuilder::new(format!("email-worker")) + .concurrency(2) .data(0usize) - .with_storage(storage) + .backend(storage) .build_fn(send_email) }) .run() @@ -122,17 +121,13 @@ async fn produce_route_jobs(storage: &RedisStorage) -> Result<()> { ## Feature flags - _tracing_ (enabled by default) — Support Tracing 👀 -- _redis_ — Include redis storage -- _postgres_ — Include Postgres storage -- _sqlite_ — Include SQlite storage -- _mysql_ — Include MySql storage -- _cron_ — Include cron job processing - _sentry_ — Support for Sentry exception and performance monitoring - _prometheus_ — Support Prometheus metrics - _retry_ — Support direct retrying jobs - _timeout_ — 
Support timeouts on jobs - _limit_ — 💊 Limit the amount of jobs - _filter_ — Support filtering jobs based on a predicate +- _catch-panic_ - Catch panics that occur during execution ## Storage Comparison @@ -214,5 +209,5 @@ See also the list of [contributors](https://github.com/geofmureithi/apalis/contr This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details [`tower::Service`]: https://docs.rs/tower/latest/tower/trait.Service.html -[tower]: https://crates.io/crates/tower +[`tower`]: https://crates.io/crates/tower [`actix`]: https://crates.io/crates/actix diff --git a/benches/storages.rs b/benches/storages.rs deleted file mode 100644 index 45b2874..0000000 --- a/benches/storages.rs +++ /dev/null @@ -1,158 +0,0 @@ -use apalis::prelude::*; - -use apalis::redis::RedisStorage; -use apalis::{ - mysql::{MySqlPool, MysqlStorage}, - postgres::{PgPool, PostgresStorage}, - sqlite::{SqlitePool, SqliteStorage}, -}; -use criterion::*; -use futures::Future; -use paste::paste; -use serde::{Deserialize, Serialize}; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::runtime::Runtime; -macro_rules! define_bench { - ($name:expr, $setup:expr ) => { - paste! 
{ - fn [<$name>](c: &mut Criterion) { - let size: usize = 1000; - - let mut group = c.benchmark_group($name); - group.sample_size(10); - group.bench_with_input(BenchmarkId::new("consume", size), &size, |b, &s| { - b.to_async(Runtime::new().unwrap()) - .iter_custom(|iters| async move { - let mut interval = tokio::time::interval(Duration::from_millis(100)); - let storage = { $setup }; - let mut s1 = storage.clone(); - let counter = Counter::default(); - let c = counter.clone(); - tokio::spawn(async move { - Monitor::::new() - .register({ - let worker = - WorkerBuilder::new(format!("{}-bench", $name)) - .data(c) - .source(storage) - .build_fn(handle_test_job); - worker - }) - .run() - .await - .unwrap(); - }); - - let start = Instant::now(); - for _ in 0..iters { - for _i in 0..s { - let _ = s1.push(TestJob).await; - } - while s1.len().await.unwrap_or(-1) != 0 { - interval.tick().await; - } - counter.0.store(0, Ordering::Relaxed); - } - let elapsed = start.elapsed(); - s1.cleanup().await; - elapsed - }) - }); - group.bench_with_input(BenchmarkId::new("push", size), &size, |b, &s| { - b.to_async(Runtime::new().unwrap()).iter(|| async move { - let mut storage = { $setup }; - let start = Instant::now(); - for _i in 0..s { - let _ = black_box(storage.push(TestJob).await); - } - start.elapsed() - }); - }); - }} - }; -} - -#[derive(Serialize, Deserialize, Debug)] -struct TestJob; - -impl Job for TestJob { - const NAME: &'static str = "TestJob"; -} - -#[derive(Debug, Default, Clone)] -struct Counter(Arc); - -async fn handle_test_job(_req: TestJob, counter: Data) -> Result<(), Error> { - counter.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - Ok(()) -} - -trait CleanUp { - fn cleanup(&mut self) -> impl Future + Send; -} - -impl CleanUp for SqliteStorage { - async fn cleanup(&mut self) { - let pool = self.pool(); - let query = "DELETE FROM Jobs; DELETE from Workers;"; - sqlx::query(query).execute(pool).await.unwrap(); - } -} - -impl CleanUp for PostgresStorage { - 
async fn cleanup(&mut self) { - let pool = self.pool(); - let query = "DELETE FROM apalis.jobs;"; - sqlx::query(query).execute(pool).await.unwrap(); - let query = "DELETE from apalis.workers;"; - sqlx::query(query).execute(pool).await.unwrap(); - } -} - -impl CleanUp for MysqlStorage { - async fn cleanup(&mut self) { - let pool = self.pool(); - let query = "DELETE FROM jobs; DELETE from workers;"; - sqlx::query(query).execute(pool).await.unwrap(); - } -} - -impl CleanUp for RedisStorage { - async fn cleanup(&mut self) { - let mut conn = self.get_connection().clone(); - let _resp: String = redis::cmd("FLUSHDB") - .query_async(&mut conn) - .await - .expect("failed to Flushdb"); - } -} - -define_bench!("sqlite_in_memory", { - let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); - let _ = SqliteStorage::setup(&pool).await; - SqliteStorage::new(pool) -}); - -define_bench!("redis", { - let conn = apalis::redis::connect(env!("REDIS_URL")).await.unwrap(); - let redis = RedisStorage::new(conn); - redis -}); - -define_bench!("postgres", { - let pool = PgPool::connect(env!("POSTGRES_URL")).await.unwrap(); - let _ = PostgresStorage::setup(&pool).await.unwrap(); - PostgresStorage::new(pool) -}); - -// define_bench!("mysql", { -// let pool = MySqlPool::connect(env!("MYSQL_URL")).await.unwrap(); -// let _ = MysqlStorage::setup(&pool).await.unwrap(); -// MysqlStorage::new(pool) -// }); - -criterion_group!(benches, sqlite_in_memory, redis, postgres); -criterion_main!(benches); diff --git a/examples/actix-web/Cargo.toml b/examples/actix-web/Cargo.toml index 43093e8..7f755b6 100644 --- a/examples/actix-web/Cargo.toml +++ b/examples/actix-web/Cargo.toml @@ -7,7 +7,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" -apalis = { path = "../../", features = ["redis"] } +apalis = { path = "../../" } +apalis-redis = { path = "../../packages/apalis-redis" } serde = "1" env_logger = "0.10" actix-web = "4" diff --git a/examples/actix-web/src/main.rs 
b/examples/actix-web/src/main.rs index 41e42be..39c786b 100644 --- a/examples/actix-web/src/main.rs +++ b/examples/actix-web/src/main.rs @@ -2,8 +2,8 @@ use actix_web::rt::signal; use actix_web::{web, App, HttpResponse, HttpServer}; use anyhow::Result; use apalis::prelude::*; -use apalis::utils::TokioExecutor; -use apalis::{layers::tracing::TraceLayer, redis::RedisStorage}; + +use apalis_redis::RedisStorage; use futures::future; use email_service::{send_email, Email}; @@ -16,7 +16,7 @@ async fn push_email( let mut storage = storage.clone(); let res = storage.push(email.into_inner()).await; match res { - Ok(jid) => HttpResponse::Ok().body(format!("Email with job_id [{jid}] added to queue")), + Ok(ctx) => HttpResponse::Ok().json(ctx), Err(e) => HttpResponse::InternalServerError().body(format!("{e}")), } } @@ -26,7 +26,7 @@ async fn main() -> Result<()> { std::env::set_var("RUST_LOG", "debug"); env_logger::init(); - let conn = apalis::redis::connect("redis://127.0.0.1/").await?; + let conn = apalis_redis::connect("redis://127.0.0.1/").await?; let storage = RedisStorage::new(conn); let data = web::Data::new(storage.clone()); let http = async { @@ -40,11 +40,12 @@ async fn main() -> Result<()> { .await?; Ok(()) }; - let worker = Monitor::::new() - .register_with_count(2, { + let worker = Monitor::new() + .register({ WorkerBuilder::new("tasty-avocado") - .layer(TraceLayer::new()) - .with_storage(storage) + .enable_tracing() + // .concurrency(2) + .backend(storage) .build_fn(send_email) }) .run_with_signal(signal::ctrl_c()); diff --git a/examples/async-std-runtime/Cargo.toml b/examples/async-std-runtime/Cargo.toml index b3645e4..f3d607f 100644 --- a/examples/async-std-runtime/Cargo.toml +++ b/examples/async-std-runtime/Cargo.toml @@ -8,13 +8,12 @@ edition = "2021" [dependencies] anyhow = "1" apalis = { path = "../../", default-features = false, features = [ - "cron", - "async-std-comp", "tracing", "retry", ] } +apalis-cron = { path = "../../packages/apalis-cron" } 
apalis-core = { path = "../../packages/apalis-core", default-features = false } -async-std = { version = "1.12.0", features = ["attributes"] } +async-std = { version = "1.13.0", features = ["attributes"] } serde = "1" tracing-subscriber = "0.3.11" chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/examples/async-std-runtime/src/main.rs b/examples/async-std-runtime/src/main.rs index 0fcd995..0d767f1 100644 --- a/examples/async-std-runtime/src/main.rs +++ b/examples/async-std-runtime/src/main.rs @@ -1,16 +1,15 @@ -use std::{future::Future, str::FromStr, time::Duration}; +use std::{str::FromStr, time::Duration}; use anyhow::Result; use apalis::{ - cron::{CronStream, Schedule}, - layers::{retry::RetryLayer, retry::RetryPolicy, tracing::MakeSpan, tracing::TraceLayer}, + layers::{retry::RetryPolicy, tracing::MakeSpan, tracing::TraceLayer}, prelude::*, }; - +use apalis_cron::{CronStream, Schedule}; use chrono::{DateTime, Utc}; use tracing::{debug, info, Instrument, Level, Span}; -type WorkerCtx = Context; +type WorkerCtx = Worker; #[derive(Default, Debug, Clone)] struct Reminder(DateTime); @@ -27,7 +26,7 @@ async fn send_in_background(reminder: Reminder) { } async fn send_reminder(reminder: Reminder, worker: WorkerCtx) -> bool { // this will happen in the workers background and wont block the next tasks - worker.spawn(send_in_background(reminder).in_current_span()); + async_std::task::spawn(worker.track(send_in_background(reminder).in_current_span())); false } @@ -43,13 +42,13 @@ async fn main() -> Result<()> { let schedule = Schedule::from_str("1/1 * * * * *").unwrap(); let worker = WorkerBuilder::new("daily-cron-worker") - .layer(RetryLayer::new(RetryPolicy::retries(5))) + .retry(RetryPolicy::retries(5)) .layer(TraceLayer::new().make_span_with(ReminderSpan::new())) - .stream(CronStream::new(schedule).into_stream()) + .backend(CronStream::new(schedule)) .build_fn(send_reminder); - Monitor::::new() - .register_with_count(2, worker) + 
Monitor::new() + .register(worker) .on_event(|e| debug!("Worker event: {e:?}")) .run_with_signal(async { ctrl_c.recv().await.ok(); @@ -60,22 +59,6 @@ async fn main() -> Result<()> { Ok(()) } -#[derive(Clone, Debug, Default)] -pub struct AsyncStdExecutor; - -impl AsyncStdExecutor { - /// A new async-std executor - pub fn new() -> Self { - Self - } -} - -impl Executor for AsyncStdExecutor { - fn spawn(&self, fut: impl Future + Send + 'static) { - async_std::task::spawn(fut); - } -} - #[derive(Debug, Clone)] pub struct ReminderSpan { level: Level, @@ -96,10 +79,10 @@ impl ReminderSpan { } } -impl MakeSpan for ReminderSpan { - fn make_span(&mut self, req: &Request) -> Span { - let task_id: &TaskId = req.get().unwrap(); - let attempts: Attempt = req.get().cloned().unwrap_or_default(); +impl MakeSpan for ReminderSpan { + fn make_span(&mut self, req: &Request) -> Span { + let task_id: &TaskId = &req.parts.task_id; + let attempts: &Attempt = &req.parts.attempt; let span = Span::current(); macro_rules! make_span { ($level:expr) => { diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml index 95a322d..5d14a34 100644 --- a/examples/axum/Cargo.toml +++ b/examples/axum/Cargo.toml @@ -11,6 +11,7 @@ tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } serde = { version = "1.0", features = ["derive"] } -apalis = { path = "../../", features = ["redis"] } +apalis = { path = "../../" } +apalis-redis = { path = "../../packages/apalis-redis" } futures = "0.3" -email-service = { path = "../email-service" } \ No newline at end of file +email-service = { path = "../email-service" } diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index a4b7146..99b3146 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -4,8 +4,9 @@ //! cd examples && cargo run -p axum-example //! 
``` use anyhow::Result; + use apalis::prelude::*; -use apalis::{layers::tracing::TraceLayer, redis::RedisStorage}; +use apalis_redis::RedisStorage; use axum::{ extract::Form, http::StatusCode, @@ -29,15 +30,15 @@ async fn add_new_job( Extension(mut storage): Extension>, ) -> impl IntoResponse where - T: 'static + Debug + Job + Serialize + DeserializeOwned + Send + Sync + Unpin, + T: 'static + Debug + Serialize + DeserializeOwned + Send + Sync + Unpin, { dbg!(&input); let new_job = storage.push(input).await; match new_job { - Ok(id) => ( + Ok(ctx) => ( StatusCode::CREATED, - format!("Job [{id}] was successfully added"), + format!("Job [{ctx:?}] was successfully added"), ), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, @@ -55,7 +56,7 @@ async fn main() -> Result<()> { )) .with(tracing_subscriber::fmt::layer()) .init(); - let conn = apalis::redis::connect("redis://127.0.0.1/").await?; + let conn = apalis_redis::connect("redis://127.0.0.1/").await?; let storage = RedisStorage::new(conn); // build our application with some routes let app = Router::new() @@ -72,11 +73,11 @@ async fn main() -> Result<()> { .map_err(|e| Error::new(std::io::ErrorKind::Interrupted, e)) }; let monitor = async { - Monitor::::new() - .register_with_count(2, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-pear") - .layer(TraceLayer::new()) - .with_storage(storage.clone()) + .enable_tracing() + .backend(storage.clone()) .build_fn(send_email) }) .run() diff --git a/examples/basics/Cargo.toml b/examples/basics/Cargo.toml index c6589c2..9016a8c 100644 --- a/examples/basics/Cargo.toml +++ b/examples/basics/Cargo.toml @@ -6,9 +6,10 @@ edition = "2021" license = "MIT OR Apache-2.0" [dependencies] -thiserror = "1" +thiserror = "2.0.0" tokio = { version = "1", features = ["full"] } -apalis = { path = "../../", features = ["sqlite", "limit", "tokio-comp"] } +apalis = { path = "../../", features = ["limit", "catch-panic"] } +apalis-sql = { path = "../../packages/apalis-sql", features = 
["sqlite"] } serde = "1" tracing-subscriber = "0.3.11" email-service = { path = "../email-service" } diff --git a/examples/basics/src/layer.rs b/examples/basics/src/layer.rs index 2918346..6c817f6 100644 --- a/examples/basics/src/layer.rs +++ b/examples/basics/src/layer.rs @@ -1,10 +1,14 @@ -use std::task::{Context, Poll}; +use std::{ + fmt::Debug, + task::{Context, Poll}, +}; use apalis::prelude::Request; use tower::{Layer, Service}; use tracing::info; /// A layer that logs a job info before it starts +#[derive(Debug, Clone)] pub struct LogLayer { target: &'static str, } @@ -34,10 +38,11 @@ pub struct LogService { service: S, } -impl Service> for LogService +impl Service> for LogService where - S: Service> + Clone, - Req: std::fmt::Debug, + S: Service> + Clone, + Req: Debug, + Ctx: Debug, { type Response = S::Response; type Error = S::Error; @@ -47,7 +52,7 @@ where self.service.poll_ready(cx) } - fn call(&mut self, request: Request) -> Self::Future { + fn call(&mut self, request: Request) -> Self::Future { // Use service to apply middleware before or(and) after a request info!("request = {:?}, target = {:?}", request, self.target); self.service.call(request) diff --git a/examples/basics/src/main.rs b/examples/basics/src/main.rs index 341f8f5..82043b9 100644 --- a/examples/basics/src/main.rs +++ b/examples/basics/src/main.rs @@ -2,20 +2,17 @@ mod cache; mod layer; mod service; -use std::time::Duration; +use std::{sync::Arc, time::Duration}; -use apalis::{ - layers::tracing::TraceLayer, - prelude::*, - sqlite::{SqlitePool, SqliteStorage}, -}; +use apalis::{layers::catch_panic::CatchPanicLayer, prelude::*}; +use apalis_sql::sqlite::{SqlitePool, SqliteStorage}; use email_service::Email; use layer::LogLayer; use tracing::{log::info, Instrument, Span}; -type WorkerCtx = Context; +type WorkerCtx = Context; use crate::{cache::ValidEmailCache, service::EmailService}; @@ -35,7 +32,7 @@ async fn produce_jobs(storage: &SqliteStorage) { } #[derive(thiserror::Error, Debug)] 
-pub enum Error { +pub enum ServiceError { #[error("data store disconnected")] Disconnect(#[from] std::io::Error), #[error("the data for key `{0}` is not available")] @@ -46,15 +43,21 @@ pub enum Error { Unknown, } +#[derive(thiserror::Error, Debug)] +pub enum PanicError { + #[error("{0}")] + Panic(String), +} + /// Quick solution to prevent spam. /// If email in cache, then send email else complete the job but let a validation process run in the background, async fn send_email( email: Email, svc: Data, worker_ctx: Data, - worker_id: WorkerId, + worker_id: Data, cache: Data, -) -> Result<(), Error> { +) -> Result<(), ServiceError> { info!("Job started in worker {:?}", worker_id); let cache_clone = cache.clone(); let email_to = email.to.clone(); @@ -66,14 +69,16 @@ async fn send_email( // This can be important for starting long running jobs that don't block the queue // Its also possible to acquire context types and clone them into the futures context. // They will also be gracefully shutdown if [`Monitor`] has a shutdown signal - worker_ctx.spawn( - async move { - if cache::fetch_validity(email_to, &cache_clone).await { - svc.send(email).await; - info!("Email added to cache") + tokio::spawn( + worker_ctx.track( + async move { + if cache::fetch_validity(email_to, &cache_clone).await { + svc.send(email).await; + info!("Email added to cache") + } } - } - .instrument(Span::current()), // Its still gonna use the jobs current tracing span. Important eg using sentry. + .instrument(Span::current()), + ), // Its still gonna use the jobs current tracing span. Important eg using sentry. 
); } @@ -96,15 +101,29 @@ async fn main() -> Result<(), std::io::Error> { let sqlite: SqliteStorage = SqliteStorage::new(pool); produce_jobs(&sqlite).await; - Monitor::::new() - .register_with_count(2, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-banana") - .layer(TraceLayer::new()) + // This handles any panics that may occur in any of the layers below + // .catch_panic() + // Or just to customize + .layer(CatchPanicLayer::with_panic_handler(|e| { + let panic_info = if let Some(s) = e.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = e.downcast_ref::() { + s.clone() + } else { + "Unknown panic".to_string() + }; + // Abort tells the backend to kill job + Error::Abort(Arc::new(Box::new(PanicError::Panic(panic_info)))) + })) + .enable_tracing() .layer(LogLayer::new("some-log-example")) // Add shared context to all jobs executed by this worker .data(EmailService::new()) .data(ValidEmailCache::new()) - .with_storage(sqlite) + .backend(sqlite) .build_fn(send_email) }) .shutdown_timeout(Duration::from_secs(5)) diff --git a/examples/catch-panic/Cargo.toml b/examples/catch-panic/Cargo.toml new file mode 100644 index 0000000..e53eed1 --- /dev/null +++ b/examples/catch-panic/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "catch-panic" +version = "0.1.0" +edition.workspace = true +repository.workspace = true + +[dependencies] +anyhow = "1" +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = ["limit", "tracing", "catch-panic"] } +apalis-sql = { path = "../../packages/apalis-sql", features = ["sqlite"] } +serde = { version = "1", features = ["derive"] } +tracing-subscriber = "0.3.11" +email-service = { path = "../email-service" } + + +[dependencies.tracing] +default-features = false +version = "0.1" + +[dependencies.sqlx] +version = "0.8" +default-features = false +features = ["sqlite", "runtime-tokio"] diff --git a/examples/catch-panic/src/main.rs b/examples/catch-panic/src/main.rs new file mode 100644 index 
0000000..de87e09 --- /dev/null +++ b/examples/catch-panic/src/main.rs @@ -0,0 +1,54 @@ +use anyhow::Result; +use apalis::prelude::*; + +use apalis_sql::sqlite::SqliteStorage; + +use email_service::Email; +use sqlx::SqlitePool; + +async fn produce_emails(storage: &mut SqliteStorage) -> Result<()> { + for i in 0..2 { + storage + .push(Email { + to: format!("test{i}@example.com"), + text: "Test background job from apalis".to_string(), + subject: "Background email job".to_string(), + }) + .await?; + } + Ok(()) +} + +async fn send_email(_: Email) { + unimplemented!("panic from unimplemented") +} + +#[tokio::main] +async fn main() -> Result<()> { + std::env::set_var("RUST_LOG", "debug,sqlx::query=info"); + tracing_subscriber::fmt::init(); + + let pool = SqlitePool::connect("sqlite::memory:").await?; + // Do migrations: Mainly for "sqlite::memory:" + SqliteStorage::setup(&pool) + .await + .expect("unable to run migrations for sqlite"); + + let mut email_storage: SqliteStorage = SqliteStorage::new(pool.clone()); + + produce_emails(&mut email_storage).await?; + + Monitor::new() + .register({ + WorkerBuilder::new("tasty-banana") + .catch_panic() + .enable_tracing() + .concurrency(2) + .backend(email_storage) + .build_fn(send_email) + }) + .on_event(|e| tracing::info!("{e:?}")) + .run() + .await?; + Ok(()) +} diff --git a/examples/cron/Cargo.toml b/examples/cron/Cargo.toml new file mode 100644 index 0000000..21e9529 --- /dev/null +++ b/examples/cron/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cron-example" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1" +apalis = { path = "../../", default-features = false, features = [ + "tracing", + "limit", + "catch-panic", +] } +apalis-cron = { path = "../../packages/apalis-cron" } +tokio = { version = "1", features = ["full"] } +serde = "1" +tracing-subscriber = "0.3.11" +chrono = { version = "0.4", default-features = false, features = ["clock"] } +pin-project-lite = "0.2.9" +tower = { version = "0.4", features = 
["load-shed"] } + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/cron/src/main.rs b/examples/cron/src/main.rs new file mode 100644 index 0000000..6826452 --- /dev/null +++ b/examples/cron/src/main.rs @@ -0,0 +1,41 @@ +use apalis::prelude::*; + +use apalis_cron::CronStream; +use apalis_cron::Schedule; +use chrono::{DateTime, Utc}; +use std::str::FromStr; +use std::time::Duration; +// use std::time::Duration; +use tower::load_shed::LoadShedLayer; + +#[derive(Clone)] +struct FakeService; +impl FakeService { + fn execute(&self, item: Reminder) { + dbg!(&item.0); + } +} + +#[derive(Default, Debug, Clone)] +struct Reminder(DateTime); +impl From> for Reminder { + fn from(t: DateTime) -> Self { + Reminder(t) + } +} +async fn send_reminder(job: Reminder, svc: Data) { + svc.execute(job); +} + +#[tokio::main] +async fn main() { + let schedule = Schedule::from_str("1/1 * * * * *").unwrap(); + let worker = WorkerBuilder::new("morning-cereal") + .enable_tracing() + .layer(LoadShedLayer::new()) // Important when you have layers that block the service + .rate_limit(1, Duration::from_secs(2)) + .data(FakeService) + .backend(CronStream::new(schedule)) + .build_fn(send_reminder); + Monitor::new().register(worker).run().await.unwrap(); +} diff --git a/examples/email-service/Cargo.toml b/examples/email-service/Cargo.toml index 8ede34e..3aca96a 100644 --- a/examples/email-service/Cargo.toml +++ b/examples/email-service/Cargo.toml @@ -8,4 +8,5 @@ apalis = { path = "../../", default-features = false } futures-util = "0.3.0" serde_json = "1.0" serde = { version = "1.0", features = ["derive"] } -log = "0.4" \ No newline at end of file +log = "0.4" +email_address = "0.2.5" diff --git a/examples/email-service/src/lib.rs b/examples/email-service/src/lib.rs index 323e115..252fee6 100644 --- a/examples/email-service/src/lib.rs +++ b/examples/email-service/src/lib.rs @@ -1,4 +1,7 @@ +use std::{str::FromStr, sync::Arc}; + use apalis::prelude::*; +use 
email_address::EmailAddress; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize, Clone)] @@ -8,12 +11,45 @@ pub struct Email { pub text: String, } -impl Job for Email { - const NAME: &'static str = "apalis::Email"; +pub async fn send_email(job: Email) -> Result<(), Error> { + let validation = EmailAddress::from_str(&job.to); + match validation { + Ok(email) => { + log::info!("Attempting to send email to {}", email.as_str()); + Ok(()) + } + Err(email_address::Error::InvalidCharacter) => { + log::error!("Killed send email job. Invalid character {}", job.to); + Err(Error::Abort(Arc::new(Box::new( + email_address::Error::InvalidCharacter, + )))) + } + Err(e) => Err(Error::Failed(Arc::new(Box::new(e)))), + } +} + +pub fn example_good_email() -> Email { + Email { + subject: "Test Subject".to_string(), + to: "example@gmail.com".to_string(), + text: "Some Text".to_string(), + } } -pub async fn send_email(job: Email) { - log::info!("Attempting to send email to {}", job.to); +pub fn example_killed_email() -> Email { + Email { + subject: "Test Subject".to_string(), + to: "example@ÂĐ.com".to_string(), // killed because it has ÂĐ which is invalid + text: "Some Text".to_string(), + } +} + +pub fn example_retry_able_email() -> Email { + Email { + subject: "Test Subject".to_string(), + to: "example".to_string(), + text: "Some Text".to_string(), + } } pub const FORM_HTML: &str = r#" diff --git a/examples/fn-args/Cargo.toml b/examples/fn-args/Cargo.toml new file mode 100644 index 0000000..ddfc725 --- /dev/null +++ b/examples/fn-args/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "fn-args" +version = "0.1.0" +edition.workspace = true +repository.workspace = true + +[dependencies] +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = ["limit", "catch-panic"] } +apalis-sql = { path = "../../packages/apalis-sql", features = [ + "sqlite", + "tokio-comp", +] } +serde = "1" +tracing-subscriber = "0.3.11" +futures = "0.3" +tower = 
"0.4" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/fn-args/src/main.rs b/examples/fn-args/src/main.rs new file mode 100644 index 0000000..99c3949 --- /dev/null +++ b/examples/fn-args/src/main.rs @@ -0,0 +1,73 @@ +use std::{ + ops::Deref, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use apalis::prelude::*; +use apalis_sql::{ + context::SqlContext, + sqlite::{SqlitePool, SqliteStorage}, +}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +#[derive(Debug, Serialize, Deserialize)] +struct SimpleJob {} + +// A task can have up to 16 arguments +async fn simple_job( + _: SimpleJob, // Required, must be of the type of the job/message + worker: Worker, // The worker and its context, added by worker + _sqlite: Data>, // The source, added by storage + task_id: TaskId, // The task id, added by storage + attempt: Attempt, // The current attempt + ctx: SqlContext, // The task context provided by the backend + count: Data, // Our custom data added via layer +) { + // increment the counter + let current = count.fetch_add(1, Ordering::Relaxed); + info!("worker: {worker:?}; task_id: {task_id:?}, ctx: {ctx:?}, attempt:{attempt:?} count: {current:?}"); +} + +async fn produce_jobs(storage: &mut SqliteStorage) { + for _ in 0..10 { + storage.push(SimpleJob {}).await.unwrap(); + } +} + +#[derive(Clone, Debug, Default)] +struct Count(Arc); + +impl Deref for Count { + type Target = Arc; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[tokio::main] +async fn main() -> Result<(), std::io::Error> { + std::env::set_var("RUST_LOG", "debug,sqlx::query=error"); + tracing_subscriber::fmt::init(); + let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); + SqliteStorage::setup(&pool) + .await + .expect("unable to run migrations for sqlite"); + let mut sqlite: SqliteStorage = SqliteStorage::new(pool); + produce_jobs(&mut sqlite).await; + Monitor::new() + .register({ + WorkerBuilder::new("tasty-banana") 
+ .data(Count::default()) + .data(sqlite.clone()) + .concurrency(2) + .backend(sqlite) + .build_fn(simple_job) + }) + .run() + .await?; + Ok(()) +} diff --git a/examples/graceful-shutdown/Cargo.toml b/examples/graceful-shutdown/Cargo.toml new file mode 100644 index 0000000..7ae58a7 --- /dev/null +++ b/examples/graceful-shutdown/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "graceful-shutdown" +version = "0.1.0" +edition.workspace = true +repository.workspace = true + +[dependencies] +thiserror = "2.0.0" +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = ["limit", "catch-panic"] } +apalis-sql = { path = "../../packages/apalis-sql", features = [ + "sqlite", + "tokio-comp", +] } +serde = "1" +tracing-subscriber = "0.3.11" +futures = "0.3" +tower = "0.4" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/graceful-shutdown/src/main.rs b/examples/graceful-shutdown/src/main.rs new file mode 100644 index 0000000..c056603 --- /dev/null +++ b/examples/graceful-shutdown/src/main.rs @@ -0,0 +1,52 @@ +use std::time::Duration; + +use apalis::prelude::*; +use apalis_sql::sqlite::{SqlitePool, SqliteStorage}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +#[derive(Debug, Serialize, Deserialize)] +struct LongRunningJob {} + +async fn long_running_task(_task: LongRunningJob, worker: Worker) { + loop { + info!("is_shutting_down: {}", worker.is_shutting_down()); + if worker.is_shutting_down() { + info!("saving the job state"); + break; + } + tokio::time::sleep(Duration::from_secs(3)).await; // Do some hard thing + } + info!("Shutdown complete!"); +} + +async fn produce_jobs(storage: &mut SqliteStorage) { + storage.push(LongRunningJob {}).await.unwrap(); +} + +#[tokio::main] +async fn main() -> Result<(), std::io::Error> { + std::env::set_var("RUST_LOG", "debug,sqlx::query=error"); + tracing_subscriber::fmt::init(); + let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); + 
SqliteStorage::setup(&pool) + .await + .expect("unable to run migrations for sqlite"); + let mut sqlite: SqliteStorage = SqliteStorage::new(pool); + produce_jobs(&mut sqlite).await; + Monitor::new() + .register({ + WorkerBuilder::new("tasty-banana") + .concurrency(2) + .enable_tracing() + .backend(sqlite) + .build_fn(long_running_task) + }) + .on_event(|e| info!("{e}")) + // Wait 5 seconds after shutdown is triggered to allow any incomplete jobs to complete + .shutdown_timeout(Duration::from_secs(5)) + // Use .run() if you don't want without signals + .run_with_signal(tokio::signal::ctrl_c()) // This will wait for ctrl+c then gracefully shutdown + .await?; + Ok(()) +} diff --git a/examples/mysql/Cargo.toml b/examples/mysql/Cargo.toml index bbd3a96..d8a6500 100644 --- a/examples/mysql/Cargo.toml +++ b/examples/mysql/Cargo.toml @@ -7,11 +7,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" -apalis = { path = "../../", features = [ - "mysql", - "tokio-comp", - "tracing", -], default-features = false } +apalis = { path = "../../", features = ["tracing"], default-features = false } +apalis-sql = { path = "../../packages/apalis-sql", features = ["mysql"] } serde = "1" tracing-subscriber = "0.3.11" chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/examples/mysql/src/main.rs b/examples/mysql/src/main.rs index 9ec2dca..cf91e78 100644 --- a/examples/mysql/src/main.rs +++ b/examples/mysql/src/main.rs @@ -1,7 +1,8 @@ use anyhow::Result; -use apalis::mysql::MySqlPool; + use apalis::prelude::*; -use apalis::{layers::tracing::TraceLayer, mysql::MysqlStorage}; +use apalis_sql::mysql::MySqlPool; +use apalis_sql::mysql::MysqlStorage; use email_service::{send_email, Email}; async fn produce_jobs(storage: &MysqlStorage) -> Result<()> { @@ -32,11 +33,11 @@ async fn main() -> Result<()> { let mysql: MysqlStorage = MysqlStorage::new(pool); produce_jobs(&mysql).await?; - Monitor::new_with_executor(TokioExecutor) - .register_with_count(1, 
{ + Monitor::new() + .register({ WorkerBuilder::new("tasty-avocado") - .layer(TraceLayer::new()) - .with_storage(mysql) + .enable_tracing() + .backend(mysql) .build_fn(send_email) }) .run() diff --git a/examples/persisted-cron/Cargo.toml b/examples/persisted-cron/Cargo.toml new file mode 100644 index 0000000..ef4fb65 --- /dev/null +++ b/examples/persisted-cron/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "persisted-cron" +version = "0.1.0" +edition.workspace = true +repository.workspace = true + +[dependencies] +anyhow = "1" +apalis = { path = "../../", default-features = false, features = [ + "tracing", + "limit", + "catch-panic", +] } +apalis-cron = { path = "../../packages/apalis-cron" } +apalis-sql = { path = "../../packages/apalis-sql", features = [ + "sqlite", + "tokio-comp", +] } +tokio = { version = "1", features = ["full"] } +serde = "1" +tracing-subscriber = "0.3.11" +chrono = { version = "0.4", default-features = false, features = ["clock"] } +pin-project-lite = "0.2.9" +tower = { version = "0.4", features = ["load-shed"] } + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/persisted-cron/src/main.rs b/examples/persisted-cron/src/main.rs new file mode 100644 index 0000000..eeb3028 --- /dev/null +++ b/examples/persisted-cron/src/main.rs @@ -0,0 +1,57 @@ +use apalis::prelude::*; + +use apalis_cron::CronStream; +use apalis_cron::Schedule; +use apalis_sql::sqlite::SqliteStorage; +use apalis_sql::sqlx::SqlitePool; +use chrono::{DateTime, Utc}; +use serde::Deserialize; +use serde::Serialize; +use std::str::FromStr; +use std::time::Duration; + +#[derive(Clone)] +struct FakeService; +impl FakeService { + fn execute(&self, item: Reminder) { + dbg!(&item.0); + } +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +struct Reminder(DateTime); +impl From> for Reminder { + fn from(t: DateTime) -> Self { + Reminder(t) + } +} +async fn send_reminder(job: Reminder, svc: Data) { + svc.execute(job); +} + +#[tokio::main] 
+async fn main() { + std::env::set_var("RUST_LOG", "debug,sqlx::query=error"); + tracing_subscriber::fmt::init(); + + // We create our cron jobs stream + let schedule = Schedule::from_str("1/1 * * * * *").unwrap(); + let cron_stream = CronStream::new(schedule); + + // Lets create a storage for our cron jobs + let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); + SqliteStorage::setup(&pool) + .await + .expect("unable to run migrations for sqlite"); + let sqlite = SqliteStorage::new(pool); + + let backend = cron_stream.pipe_to_storage(sqlite); + + let worker = WorkerBuilder::new("morning-cereal") + .enable_tracing() + .rate_limit(1, Duration::from_secs(2)) + .data(FakeService) + .backend(backend) + .build_fn(send_reminder); + Monitor::new().register(worker).run().await.unwrap(); +} diff --git a/examples/postgres/Cargo.toml b/examples/postgres/Cargo.toml index bff7094..151c8a4 100644 --- a/examples/postgres/Cargo.toml +++ b/examples/postgres/Cargo.toml @@ -7,7 +7,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" -apalis = { path = "../../", features = ["postgres", "retry"] } +apalis = { path = "../../", features = ["retry"] } +apalis-sql = { path = "../../packages/apalis-sql", features = ["postgres"] } serde = "1" tracing-subscriber = "0.3.11" chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/examples/postgres/src/main.rs b/examples/postgres/src/main.rs index b4983a1..1b4ba9b 100644 --- a/examples/postgres/src/main.rs +++ b/examples/postgres/src/main.rs @@ -1,14 +1,12 @@ use anyhow::Result; use apalis::layers::retry::RetryPolicy; -use apalis::postgres::PgPool; + use apalis::prelude::*; -use apalis::{layers::retry::RetryLayer, layers::tracing::TraceLayer, postgres::PostgresStorage}; +use apalis_sql::postgres::{PgListen, PgPool, PostgresStorage}; use email_service::{send_email, Email}; use tracing::{debug, info}; -async fn produce_jobs(storage: &PostgresStorage) -> Result<()> { - // The programmatic way 
- let mut storage = storage.clone(); +async fn produce_jobs(storage: &mut PostgresStorage) -> Result<()> { for index in 0..10 { storage .push(Email { @@ -34,18 +32,26 @@ async fn main() -> Result<()> { .await .expect("unable to run migrations for postgres"); - let pg = PostgresStorage::new(pool); - produce_jobs(&pg).await?; + let mut pg = PostgresStorage::new(pool.clone()); + produce_jobs(&mut pg).await?; + + let mut listener = PgListen::new(pool).await?; + + listener.subscribe_with(&mut pg); + + tokio::spawn(async move { + listener.listen().await.unwrap(); + }); - Monitor::::new() - .register_with_count(4, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-orange") - .layer(TraceLayer::new()) - .layer(RetryLayer::new(RetryPolicy::retries(5))) - .with_storage(pg.clone()) + .enable_tracing() + .retry(RetryPolicy::retries(5)) + .backend(pg) .build_fn(send_email) }) - .on_event(|e| debug!("{e:?}")) + .on_event(|e| debug!("{e}")) .run_with_signal(async { tokio::signal::ctrl_c().await?; info!("Shutting down the system"); diff --git a/examples/prometheus/Cargo.toml b/examples/prometheus/Cargo.toml index 32a1647..1668918 100644 --- a/examples/prometheus/Cargo.toml +++ b/examples/prometheus/Cargo.toml @@ -11,7 +11,8 @@ tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } serde = { version = "1.0", features = ["derive"] } -apalis = { path = "../../", features = ["redis", "prometheus"] } +apalis = { path = "../../", features = ["prometheus"] } +apalis-redis = { path = "../../packages/apalis-redis" } futures = "0.3" metrics = "0.21" metrics-exporter-prometheus = "0.12" diff --git a/examples/prometheus/src/main.rs b/examples/prometheus/src/main.rs index 1288b41..0f85eb0 100644 --- a/examples/prometheus/src/main.rs +++ b/examples/prometheus/src/main.rs @@ -4,8 +4,9 @@ //! cd examples && cargo run -p prometheus-example //! 
``` use anyhow::Result; +use apalis::layers::prometheus::PrometheusLayer; use apalis::prelude::*; -use apalis::{layers::prometheus::PrometheusLayer, redis::RedisStorage}; +use apalis_redis::RedisStorage; use axum::{ extract::Form, http::StatusCode, @@ -29,7 +30,7 @@ async fn main() -> Result<()> { )) .with(tracing_subscriber::fmt::layer()) .init(); - let conn = apalis::redis::connect("redis://127.0.0.1/").await?; + let conn = apalis_redis::connect("redis://127.0.0.1/").await?; let storage = RedisStorage::new(conn); // build our application with some routes let recorder_handle = setup_metrics_recorder(); @@ -47,11 +48,11 @@ async fn main() -> Result<()> { .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) }; let monitor = async { - Monitor::::new() - .register_with_count(2, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-banana") - .layer(PrometheusLayer) - .with_storage(storage.clone()) + .layer(PrometheusLayer::default()) + .backend(storage.clone()) .build_fn(send_email) }) .run() @@ -87,15 +88,15 @@ async fn add_new_job( Extension(mut storage): Extension>, ) -> impl IntoResponse where - T: 'static + Debug + Job + Serialize + DeserializeOwned + Unpin + Send + Sync, + T: 'static + Debug + Serialize + DeserializeOwned + Unpin + Send + Sync, { dbg!(&input); let new_job = storage.push(input).await; match new_job { - Ok(jid) => ( + Ok(ctx) => ( StatusCode::CREATED, - format!("Job [{jid}] was successfully added"), + format!("Job [{ctx:?}] was successfully added"), ), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, diff --git a/examples/redis-deadpool/Cargo.toml b/examples/redis-deadpool/Cargo.toml new file mode 100644 index 0000000..be1ac8f --- /dev/null +++ b/examples/redis-deadpool/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "redis-deadpool" +version = "0.1.0" +edition = "2021" + +[dependencies] +deadpool-redis = { version = "0.18" } +anyhow = "1" +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = 
["timeout"] } +apalis-redis = { path = "../../packages/apalis-redis" } +serde = "1" +env_logger = "0.10" +tracing-subscriber = "0.3.11" +chrono = { version = "0.4", default-features = false, features = ["clock"] } +email-service = { path = "../email-service" } +rmp-serde = "1.3" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/redis-deadpool/src/main.rs b/examples/redis-deadpool/src/main.rs new file mode 100644 index 0000000..74f5829 --- /dev/null +++ b/examples/redis-deadpool/src/main.rs @@ -0,0 +1,57 @@ +use std::time::Duration; + +use anyhow::Result; +use apalis::prelude::*; +use apalis_redis::RedisStorage; + +use deadpool_redis::{Config, Connection, Runtime}; +use email_service::{send_email, Email}; +use tracing::info; + +#[tokio::main] +async fn main() -> Result<()> { + std::env::set_var("RUST_LOG", "debug"); + + tracing_subscriber::fmt::init(); + + let config = apalis_redis::Config::default() + .set_namespace("apalis_redis-dead-pool") + .set_max_retries(5); + + let cfg = Config::from_url("redis://127.0.0.1/"); + let pool = cfg.create_pool(Some(Runtime::Tokio1)).unwrap(); + let conn = pool.get().await.unwrap(); + let mut storage = RedisStorage::new_with_config(conn, config); + // This can be in another part of the program + produce_jobs(&mut storage).await?; + + let worker = WorkerBuilder::new("rango-tango") + .data(pool) + .backend(storage) + .build_fn(send_email); + + Monitor::new() + .register(worker) + .shutdown_timeout(Duration::from_millis(5000)) + .run_with_signal(async { + tokio::signal::ctrl_c().await?; + info!("Monitor starting shutdown"); + Ok(()) + }) + .await?; + info!("Monitor shutdown complete"); + Ok(()) +} + +async fn produce_jobs(storage: &mut RedisStorage) -> Result<()> { + for index in 0..10 { + storage + .push(Email { + to: index.to_string(), + text: "Test background job from apalis".to_string(), + subject: "Background email job".to_string(), + }) + .await?; + } + Ok(()) +} diff --git 
a/examples/redis-mq-example/Cargo.toml b/examples/redis-mq-example/Cargo.toml new file mode 100644 index 0000000..d841cb7 --- /dev/null +++ b/examples/redis-mq-example/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "redis-mq-example" +version = "0.1.0" +edition = "2021" + +[dependencies] +apalis = { path = "../.." } +apalis-redis = { path = "../../packages/apalis-redis" } +apalis-core = { path = "../../packages/apalis-core", features = ["json"] } +rsmq_async = "11.1.0" +anyhow = "1" +tokio = { version = "1", features = ["full"] } +serde = "1" +env_logger = "0.10" +tracing-subscriber = "0.3.11" +chrono = { version = "0.4", default-features = false, features = ["clock"] } +email-service = { path = "../email-service" } +rmp-serde = "1.3" +tower = "0.4" +futures = "0.3" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/redis-mq-example/src/main.rs b/examples/redis-mq-example/src/main.rs new file mode 100644 index 0000000..1cfe723 --- /dev/null +++ b/examples/redis-mq-example/src/main.rs @@ -0,0 +1,192 @@ +use std::{fmt::Debug, marker::PhantomData, time::Duration}; + +use apalis::prelude::*; + +use apalis_redis::{self, Config}; + +use apalis_core::{ + codec::json::JsonCodec, + layers::{Ack, AckLayer}, + response::Response, +}; +use email_service::{send_email, Email}; +use futures::{channel::mpsc, SinkExt}; +use rsmq_async::{Rsmq, RsmqConnection, RsmqError}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::time::sleep; +use tracing::info; + +struct RedisMq>> { + conn: Rsmq, + msg_type: PhantomData, + config: Config, + codec: PhantomData, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +pub struct RedisMqContext { + max_attempts: usize, + message_id: String, +} + +impl FromRequest> for RedisMqContext { + fn from_request(req: &Request) -> Result { + Ok(req.parts.context.clone()) + } +} + +// Manually implement Clone for RedisMq +impl Clone for RedisMq { + fn clone(&self) -> Self { + RedisMq { 
+ conn: self.conn.clone(), + msg_type: PhantomData, + config: self.config.clone(), + codec: self.codec, + } + } +} + +impl Backend, Res> for RedisMq +where + Req: Send + DeserializeOwned + 'static, + C: Codec>, +{ + type Stream = RequestStream>; + + type Layer = AckLayer; + + fn poll(mut self, _worker: &Worker) -> Poller { + let (mut tx, rx) = mpsc::channel(self.config.get_buffer_size()); + let stream: RequestStream> = Box::pin(rx); + let layer = AckLayer::new(self.clone()); + let heartbeat = async move { + loop { + sleep(*self.config.get_poll_interval()).await; + let msg: Option> = self + .conn + .receive_message(self.config.get_namespace(), None) + .await + .unwrap() + .map(|r| { + let mut req: Request = + C::decode(r.message).map_err(Into::into).unwrap(); + req.insert(r.id); + req + }); + tx.send(Ok(msg)).await.unwrap(); + } + }; + Poller::new_with_layer(stream, heartbeat, layer) + } +} + +impl Ack for RedisMq +where + T: Send, + Res: Debug + Send + Sync, + C: Send, +{ + type Context = RedisMqContext; + + type AckError = RsmqError; + + async fn ack( + &mut self, + ctx: &Self::Context, + res: &Response, + ) -> Result<(), Self::AckError> { + if res.is_success() || res.attempt.current() >= ctx.max_attempts { + self.conn + .delete_message(self.config.get_namespace(), &ctx.message_id) + .await?; + } + Ok(()) + } +} + +impl MessageQueue for RedisMq +where + Message: Send + Serialize + DeserializeOwned + 'static, + C: Codec> + Send, +{ + type Error = RsmqError; + + async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> { + let bytes = C::encode(Request::::new(message)) + .map_err(Into::into) + .unwrap(); + self.conn + .send_message(self.config.get_namespace(), bytes, None) + .await?; + Ok(()) + } + + async fn dequeue(&mut self) -> Result, Self::Error> { + Ok(self + .conn + .receive_message(self.config.get_namespace(), None) + .await? 
+ .map(|r| { + let req: Request = + C::decode(r.message).map_err(Into::into).unwrap(); + req.args + })) + } + + async fn size(&mut self) -> Result { + self.conn + .get_queue_attributes(self.config.get_namespace()) + .await? + .msgs + .try_into() + .map_err(|_| RsmqError::InvalidFormat("Could not convert to usize".to_owned())) + } +} + +async fn produce_jobs(mq: &mut RedisMq) -> anyhow::Result<()> { + for index in 0..1 { + mq.enqueue(Email { + to: index.to_string(), + text: "Test background job from apalis".to_string(), + subject: "Background email job".to_string(), + }) + .await?; + } + Ok(()) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + std::env::set_var("RUST_LOG", "debug"); + + tracing_subscriber::fmt::init(); + + let mut conn = rsmq_async::Rsmq::new(Default::default()).await?; + let _ = conn.create_queue("email", None, None, None).await; + let mut mq = RedisMq { + conn, + msg_type: PhantomData, + codec: PhantomData, + config: Config::default().set_namespace("email"), + }; + produce_jobs(&mut mq).await?; + + let worker = WorkerBuilder::new("rango-tango") + .enable_tracing() + .backend(mq) + .build_fn(send_email); + + Monitor::new() + .register(worker) + .on_event(|e| info!("{e}")) + .shutdown_timeout(Duration::from_millis(5000)) + .run_with_signal(async { + tokio::signal::ctrl_c().await?; + info!("Monitor starting shutdown"); + Ok(()) + }) + .await?; + info!("Monitor shutdown complete"); + Ok(()) +} diff --git a/examples/redis-with-msg-pack/Cargo.toml b/examples/redis-with-msg-pack/Cargo.toml new file mode 100644 index 0000000..c6ec5cc --- /dev/null +++ b/examples/redis-with-msg-pack/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "redis-with-msg-pack" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1" +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = ["timeout"] } +apalis-redis = { path = "../../packages/apalis-redis" } +serde = "1" +env_logger = "0.10" +tracing-subscriber = "0.3.11" 
+chrono = { version = "0.4", default-features = false, features = ["clock"] } +email-service = { path = "../email-service" } +rmp-serde = "1.3" +redis = "0.27" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/redis-with-msg-pack/src/main.rs b/examples/redis-with-msg-pack/src/main.rs new file mode 100644 index 0000000..61613c5 --- /dev/null +++ b/examples/redis-with-msg-pack/src/main.rs @@ -0,0 +1,73 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use apalis::prelude::*; +use apalis_redis::RedisStorage; + +use email_service::{send_email, Email}; +use redis::aio::ConnectionManager; +use serde::{Deserialize, Serialize}; +use tracing::info; + +struct MessagePack; + +impl Codec for MessagePack { + type Compact = Vec; + type Error = Error; + fn encode(input: T) -> Result, Self::Error> { + rmp_serde::to_vec(&input).map_err(|e| Error::SourceError(Arc::new(Box::new(e)))) + } + + fn decode(compact: Vec) -> Result + where + O: for<'de> Deserialize<'de>, + { + rmp_serde::from_slice(&compact).map_err(|e| Error::SourceError(Arc::new(Box::new(e)))) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + std::env::set_var("RUST_LOG", "debug"); + + tracing_subscriber::fmt::init(); + + let conn = apalis_redis::connect("redis://127.0.0.1/").await?; + let config = apalis_redis::Config::default() + .set_namespace("apalis_redis-with-msg-pack") + .set_max_retries(5); + let storage = RedisStorage::new_with_codec::(conn, config); + // This can be in another part of the program + produce_jobs(storage.clone()).await?; + + let worker = WorkerBuilder::new("rango-tango") + .backend(storage) + .build_fn(send_email); + + Monitor::new() + .register(worker) + .shutdown_timeout(Duration::from_millis(5000)) + .run_with_signal(async { + tokio::signal::ctrl_c().await?; + info!("Monitor starting shutdown"); + Ok(()) + }) + .await?; + info!("Monitor shutdown complete"); + Ok(()) +} + +async fn produce_jobs( + mut storage: RedisStorage, +) 
-> Result<()> { + for index in 0..10 { + storage + .push(Email { + to: index.to_string(), + text: "Test background job from apalis".to_string(), + subject: "Background email job".to_string(), + }) + .await?; + } + Ok(()) +} diff --git a/examples/redis/Cargo.toml b/examples/redis/Cargo.toml index 2f8a9be..f3549b7 100644 --- a/examples/redis/Cargo.toml +++ b/examples/redis/Cargo.toml @@ -8,7 +8,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" tokio = { version = "1", features = ["full"] } -apalis = { path = "../../", features = ["redis", "timeout"]} +apalis = { path = "../../", features = ["timeout", "limit"] } +apalis-redis = { path = "../../packages/apalis-redis" } serde = "1" env_logger = "0.10" tracing-subscriber = "0.3.11" diff --git a/examples/redis/src/main.rs b/examples/redis/src/main.rs index d9b80e7..708825e 100644 --- a/examples/redis/src/main.rs +++ b/examples/redis/src/main.rs @@ -1,18 +1,15 @@ -use std::{ - ops::Deref, - sync::{atomic::AtomicUsize, Arc}, - time::Duration, -}; +use std::time::Duration; use anyhow::Result; +use apalis::layers::ErrorHandlingLayer; use apalis::prelude::*; -use apalis::redis::RedisStorage; +use apalis_redis::RedisStorage; use email_service::{send_email, Email}; use tracing::{error, info}; async fn produce_jobs(mut storage: RedisStorage) -> Result<()> { - for index in 0..1 { + for index in 0..10 { storage .push(Email { to: index.to_string(), @@ -24,36 +21,28 @@ async fn produce_jobs(mut storage: RedisStorage) -> Result<()> { Ok(()) } -#[derive(Clone, Debug, Default)] -struct Count(Arc); - -impl Deref for Count { - type Target = Arc; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - #[tokio::main] async fn main() -> Result<()> { std::env::set_var("RUST_LOG", "debug"); tracing_subscriber::fmt::init(); - let conn = apalis::redis::connect("redis://127.0.0.1/").await?; - let config = apalis::redis::Config::default(); - let storage = RedisStorage::new_with_config(conn, config); + let conn = 
apalis_redis::connect("redis://127.0.0.1/").await?; + let storage = RedisStorage::new(conn); // This can be in another part of the program produce_jobs(storage.clone()).await?; let worker = WorkerBuilder::new("rango-tango") - .chain(|svc| svc.timeout(Duration::from_millis(500))) - .data(Count::default()) - .with_storage(storage) + .layer(ErrorHandlingLayer::new()) + .enable_tracing() + .rate_limit(5, Duration::from_secs(1)) + .timeout(Duration::from_millis(500)) + .concurrency(2) + .backend(storage) .build_fn(send_email); - Monitor::::new() - .register_with_count(2, worker) + Monitor::new() + .register(worker) .on_event(|e| { let worker_id = e.id(); match e.inner() { @@ -72,6 +61,7 @@ async fn main() -> Result<()> { }) .shutdown_timeout(Duration::from_millis(5000)) .run_with_signal(async { + info!("Monitor started"); tokio::signal::ctrl_c().await?; info!("Monitor starting shutdown"); Ok(()) diff --git a/examples/rest-api/Cargo.toml b/examples/rest-api/Cargo.toml index a6bd252..0dfc63a 100644 --- a/examples/rest-api/Cargo.toml +++ b/examples/rest-api/Cargo.toml @@ -1,19 +1,15 @@ [package] -name = "rest-api-example" +name = "rest-api" version = "0.1.0" -authors = ["Njuguna Mureithi "] -edition = "2018" -license = "MIT OR Apache-2.0" +edition.workspace = true +repository.workspace = true [dependencies] anyhow = "1" -apalis = { path = "../../", features = ["redis", "sqlite", "sentry", "postgres", "mysql", "expose"] } +apalis = { path = "../../" } +apalis-redis = { path = "../../packages/apalis-redis" } serde = "1" -tokio = { version = "1", features =["macros", "rt-multi-thread"] } env_logger = "0.10" actix-web = "4" futures = "0.3" -actix-cors = "0.6.1" -serde_json = "1" -chrono = { version = "0.4", default-features = false, features = ["clock"] } email-service = { path = "../email-service" } diff --git a/examples/rest-api/README.md b/examples/rest-api/README.md index 8a3c220..ea4f534 100644 --- a/examples/rest-api/README.md +++ b/examples/rest-api/README.md @@ -2,16 
+2,4 @@ ![UI](https://github.com/geofmureithi/apalis-board/raw/master/screenshots/workers.png) -## Backend - -``` -cd examples && cargo run -p rest-api -``` - -## Frontend - -``` -git clone https://github.com/geofmureithi/apalis-board -cd apalis-board -yarn && yarn start:dev -``` +Please see https://github.com/geofmureithi/apalis-board for a working example diff --git a/examples/rest-api/src/main.rs b/examples/rest-api/src/main.rs index 0f90ff0..e8bc43a 100644 --- a/examples/rest-api/src/main.rs +++ b/examples/rest-api/src/main.rs @@ -1,349 +1,108 @@ -use std::collections::HashSet; -use std::time::Duration; +use actix_web::rt::signal; +use actix_web::{web, App, HttpResponse, HttpServer}; +use anyhow::Result; +use apalis::prelude::*; -use actix_cors::Cors; -use actix_web::{web, App, HttpResponse, HttpServer, Scope}; -use apalis::{ - layers::{SentryJobLayer, TraceLayer}, - mysql::MysqlStorage, - postgres::PostgresStorage, - prelude::*, - redis::RedisStorage, - sqlite::SqliteStorage, -}; +use apalis_redis::RedisStorage; use futures::future; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; use email_service::{send_email, Email}; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize, Serialize)] -struct Notification { - text: String, -} - -impl Job for Notification { - const NAME: &'static str = "sqlite::Notification"; -} - -async fn notification_service(notif: Notification) -> Result<(), Error> { - println!("Attempting to send notification {}", notif.text); - tokio::time::sleep(Duration::from_millis(1)).await; - Ok(()) -} - -#[derive(Debug, Deserialize, Serialize)] -struct Document { - text: String, -} - -impl Job for Document { - const NAME: &'static str = "postgres::Document"; -} - -async fn document_service(doc: Document) -> Result<(), Error> { - println!("Attempting to convert {} to pdf", doc.text); - tokio::time::sleep(Duration::from_millis(1)).await; - Ok(()) -} - -#[derive(Debug, Deserialize, Serialize)] -struct Upload { - url: String, 
+#[derive(Deserialize, Debug)] +struct Filter { + #[serde(default)] + pub status: State, + #[serde(default = "default_page")] + pub page: i32, } -impl Job for Upload { - const NAME: &'static str = "mysql::Upload"; +fn default_page() -> i32 { + 1 } -async fn upload_service(upload: Upload) -> Result<(), Error> { - println!("Attempting to upload {} to cloud", upload.url); - tokio::time::sleep(Duration::from_millis(1)).await; - Ok(()) +#[derive(Debug, Serialize, Deserialize)] +struct GetJobsResult { + pub stats: Stat, + pub jobs: Vec, } -#[derive(Serialize)] -struct JobsResult { - jobs: Vec>, - counts: StateCount, -} -#[derive(Deserialize)] -struct Filter { - #[serde(default)] - status: State, - #[serde(default)] - page: i32, -} - -async fn push_job(job: web::Json, storage: web::Data) -> HttpResponse -where - J: Job + Serialize + DeserializeOwned + 'static, - S: Storage, -{ - let storage = &*storage.into_inner(); - let mut storage = storage.clone(); +async fn push_job(job: web::Json, storage: web::Data>) -> HttpResponse { + let mut storage = (**storage).clone(); let res = storage.push(job.into_inner()).await; match res { - Ok(id) => HttpResponse::Ok().body(format!("Job with ID [{id}] added to queue")), - Err(e) => HttpResponse::InternalServerError().body(format!("{e}")), + Ok(parts) => { + HttpResponse::Ok().body(format!("Job with ID [{}] added to queue", parts.task_id)) + } + Err(e) => HttpResponse::InternalServerError().json(e.to_string()), } } -async fn get_jobs(storage: web::Data, filter: web::Query) -> HttpResponse -where - J: Job + Serialize + DeserializeOwned + 'static, - S: Storage + JobStreamExt + Send, -{ - let storage = &*storage.into_inner(); - let mut storage = storage.clone(); - let counts = storage.counts().await.unwrap(); - let jobs = storage.list_jobs(&filter.status, filter.page).await; - - match jobs { - Ok(jobs) => HttpResponse::Ok().json(JobsResult { jobs, counts }), - Err(e) => HttpResponse::InternalServerError().body(format!("{e}")), +async fn 
get_jobs( + storage: web::Data>, + filter: web::Query, +) -> HttpResponse { + let stats = storage.stats().await.unwrap_or_default(); + let res = storage.list_jobs(&filter.status, filter.page).await; + match res { + Ok(jobs) => HttpResponse::Ok().json(GetJobsResult { stats, jobs }), + Err(e) => HttpResponse::InternalServerError().json(e.to_string()), } } -async fn get_workers(storage: web::Data) -> HttpResponse -where - J: Job + Serialize + DeserializeOwned + 'static, - S: Storage + JobStreamExt, -{ - let storage = &*storage.into_inner(); - let mut storage = storage.clone(); +async fn get_workers(storage: web::Data>) -> HttpResponse { let workers = storage.list_workers().await; match workers { - Ok(workers) => HttpResponse::Ok().json(serde_json::to_value(workers).unwrap()), - Err(e) => HttpResponse::InternalServerError().body(format!("{e}")), + Ok(workers) => HttpResponse::Ok().json(workers), + Err(e) => HttpResponse::InternalServerError().json(e.to_string()), } } -async fn get_job(job_id: web::Path, storage: web::Data) -> HttpResponse -where - J: Job + Serialize + DeserializeOwned + 'static, - S: Storage + 'static, -{ - let storage = &*storage.into_inner(); - let storage = storage.clone(); +async fn get_job( + job_id: web::Path, + storage: web::Data>, +) -> HttpResponse { + let mut storage = (**storage).clone(); + let res = storage.fetch_by_id(&job_id).await; match res { Ok(Some(job)) => HttpResponse::Ok().json(job), Ok(None) => HttpResponse::NotFound().finish(), - Err(e) => HttpResponse::InternalServerError().body(format!("{e}")), - } -} - -trait StorageRest: Storage { - fn name(&self) -> String; -} - -impl StorageRest for S -where - S: Storage + JobStreamExt + 'static, - J: Job + Serialize + DeserializeOwned + 'static, -{ - fn name(&self) -> String { - J::NAME.to_string() + Err(e) => HttpResponse::InternalServerError().json(e.to_string()), } } -#[derive(Debug, Deserialize, Serialize)] -struct Queue { - name: String, -} - -#[derive(Debug, Deserialize, Serialize)] 
-struct QueueList { - set: HashSet, -} - -struct StorageApiBuilder { - scope: Scope, - list: QueueList, -} - -impl StorageApiBuilder { - fn add_storage(mut self, storage: S) -> Self - where - J: Job + Serialize + DeserializeOwned + 'static, - S: StorageRest + JobStreamExt, - S: Storage, - S: 'static + Send, - { - let name = J::NAME.to_string(); - self.list.set.insert(name); - - Self { - scope: self.scope.service( - Scope::new(J::NAME) - .app_data(web::Data::new(storage)) - .route("", web::get().to(get_jobs::)) // Fetch jobs in queue - .route("/workers", web::get().to(get_workers::)) // Fetch jobs in queue - .route("/job", web::put().to(push_job::)) // Allow add jobs via api - .route("/job/{job_id}", web::get().to(get_job::)), // Allow fetch specific job - ), - list: self.list, - } - } - - fn build(self) -> Scope { - async fn fetch_queues(queues: web::Data) -> HttpResponse { - let mut queue_result = Vec::new(); - for queue in &queues.set { - queue_result.push(Queue { - name: queue.clone(), - }) - } - #[derive(Serialize)] - struct Res { - queues: Vec, - } - - HttpResponse::Ok().json(Res { - queues: queue_result, - }) - } - - self.scope - .app_data(web::Data::new(self.list)) - .route("", web::get().to(fetch_queues)) - } - - fn new() -> Self { - Self { - scope: Scope::new("queues"), - list: QueueList { - set: HashSet::new(), - }, - } - } -} - -async fn produce_redis_jobs(mut storage: RedisStorage) { - for i in 0..10 { - storage - .push(Email { - to: format!("test{i}@example.com"), - text: "Test background job from apalis".to_string(), - subject: "Background email job".to_string(), - }) - .await - .unwrap(); - } -} -async fn produce_sqlite_jobs(mut storage: SqliteStorage) { - for i in 0..100 { - storage - .push(Notification { - text: format!("Notiification: {i}"), - }) - .await - .unwrap(); - } -} - -async fn produce_postgres_jobs(mut storage: PostgresStorage) { - for i in 0..100 { - storage - .push(Document { - text: format!("Document: {i}"), - }) - .await - .unwrap(); 
- } -} - -async fn produce_mysql_jobs(mut storage: MysqlStorage) { - for i in 0..100 { - storage - .push(Upload { - url: format!("Upload: {i}"), - }) - .await - .unwrap(); - } -} - -#[tokio::main(flavor = "multi_thread", worker_threads = 10)] -async fn main() -> anyhow::Result<()> { - std::env::set_var("RUST_LOG", "debug,sqlx::query=error"); +#[actix_web::main] +async fn main() -> Result<()> { + std::env::set_var("RUST_LOG", "debug"); env_logger::init(); - let database_url = std::env::var("DATABASE_URL").expect("Must specify DATABASE_URL"); - let pg: PostgresStorage = PostgresStorage::connect(database_url).await?; - pg.setup().await.expect("Unable to migrate"); - - let database_url = std::env::var("MYSQL_URL").expect("Must specify MYSQL_URL"); - - let mysql: MysqlStorage = MysqlStorage::connect(database_url).await?; - mysql - .setup() - .await - .expect("unable to run migrations for mysql"); - let storage = RedisStorage::connect("redis://127.0.0.1/").await?; - - let sqlite = SqliteStorage::connect("sqlite://data.db").await?; - sqlite.setup().await.expect("Unable to migrate"); - - let worker_storage = storage.clone(); - let sqlite_storage = sqlite.clone(); - let pg_storage = pg.clone(); - let mysql_storage = mysql.clone(); - - produce_redis_jobs(storage.clone()).await; - produce_sqlite_jobs(sqlite.clone()).await; - produce_postgres_jobs(pg_storage.clone()).await; - produce_mysql_jobs(mysql.clone()).await; + let conn = apalis_redis::connect("redis://127.0.0.1/").await?; + let storage = RedisStorage::new(conn); + let data = web::Data::new(storage.clone()); let http = async { HttpServer::new(move || { - App::new().wrap(Cors::permissive()).service( - web::scope("/api").service( - StorageApiBuilder::new() - .add_storage(storage.clone()) - .add_storage(sqlite.clone()) - .add_storage(pg.clone()) - .add_storage(mysql.clone()) - .build(), - ), - ) + App::new() + .app_data(data.clone()) + .route("/", web::get().to(get_jobs)) // Fetch jobs in queue + .route("/workers", 
web::get().to(get_workers)) // Fetch workers + .route("/job", web::put().to(push_job)) // Allow add jobs via api + .route("/job/{job_id}", web::get().to(get_job)) // Allow fetch specific job }) .bind("127.0.0.1:8000")? .run() .await?; Ok(()) }; - let worker = Monitor::new() - .register_with_count(1, move |_| { - WorkerBuilder::new("tasty-apple") - .layer(SentryJobLayer) - .layer(TraceLayer::new()) - .with_storage(worker_storage.clone()) + .register({ + WorkerBuilder::new("tasty-avocado") + .enable_tracing() + .backend(storage) .build_fn(send_email) }) - .register_with_count(4, move |c| { - WorkerBuilder::new(format!("tasty-avocado-{c}")) - .layer(SentryJobLayer) - .layer(TraceLayer::new()) - .with_storage(sqlite_storage.clone()) - .build_fn(notification_service) - }) - .register_with_count(2, move |c| { - WorkerBuilder::new(format!("tasty-banana-{c}")) - .layer(SentryJobLayer) - .layer(TraceLayer::new()) - .with_storage(pg_storage.clone()) - .build_fn(document_service) - }) - .register_with_count(2, move |c| { - WorkerBuilder::new(format!("tasty-pear-{c}")) - .layer(SentryJobLayer::new()) - .layer(TraceLayer::new()) - .with_storage(mysql_storage.clone()) - .build_fn(upload_service) - }) - .run(); - future::try_join(http, worker).await?; + .run_with_signal(signal::ctrl_c()); + future::try_join(http, worker).await?; Ok(()) } diff --git a/examples/sentry/Cargo.toml b/examples/sentry/Cargo.toml index 20b1a1a..6e51b11 100644 --- a/examples/sentry/Cargo.toml +++ b/examples/sentry/Cargo.toml @@ -7,7 +7,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" -apalis = { path = "../../", features = ["redis", "sentry"] } +apalis = { path = "../../", features = ["sentry"] } +apalis-redis = { path = "../../packages/apalis-redis" } serde = "1" env_logger = "0.10" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } diff --git a/examples/sentry/src/main.rs b/examples/sentry/src/main.rs index 376507a..083734c 100644 --- a/examples/sentry/src/main.rs +++ 
b/examples/sentry/src/main.rs @@ -6,11 +6,9 @@ use std::time::Duration; use tracing_subscriber::prelude::*; use anyhow::Result; -use apalis::{ - layers::{sentry::SentryLayer, tracing::TraceLayer}, - prelude::*, - redis::RedisStorage, -}; + +use apalis::{layers::sentry::SentryLayer, prelude::*}; +use apalis_redis::RedisStorage; use email_service::Email; use tokio::time::sleep; @@ -129,18 +127,19 @@ async fn main() -> Result<()> { .with(sentry_tracing::layer()) .init(); - let conn = apalis::redis::connect(redis_url).await?; + let conn = apalis_redis::connect(redis_url).await?; let storage = RedisStorage::new(conn); //This can be in another part of the program produce_jobs(storage.clone()).await?; - Monitor::::new() - .register_with_count(2, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-avocado") .layer(NewSentryLayer::new_from_top()) .layer(SentryLayer::new()) - .layer(TraceLayer::new()) - .with_storage(storage.clone()) + .enable_tracing() + .concurrency(2) + .backend(storage.clone()) .build_fn(email_service) }) .run() diff --git a/examples/sqlite/Cargo.toml b/examples/sqlite/Cargo.toml index c76d7b3..9f7d502 100644 --- a/examples/sqlite/Cargo.toml +++ b/examples/sqlite/Cargo.toml @@ -8,10 +8,9 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" tokio = { version = "1", features = ["full"] } -apalis = { path = "../../", features = [ +apalis = { path = "../../", features = ["limit", "tracing"] } +apalis-sql = { path = "../../packages/apalis-sql", features = [ "sqlite", - "limit", - "tracing", "tokio-comp", ] } serde = { version = "1", features = ["derive"] } diff --git a/examples/sqlite/src/job.rs b/examples/sqlite/src/job.rs index a8311d2..4e0dded 100644 --- a/examples/sqlite/src/job.rs +++ b/examples/sqlite/src/job.rs @@ -1,4 +1,3 @@ -use apalis::prelude::*; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] @@ -7,10 +6,6 @@ pub struct Notification { pub text: String, } -impl Job for Notification { - const NAME: 
&'static str = "apalis::Notification"; -} - pub async fn notify(job: Notification) { tracing::info!("Attempting to send notification to {}", job.to); } diff --git a/examples/sqlite/src/main.rs b/examples/sqlite/src/main.rs index 2b086b0..0b43210 100644 --- a/examples/sqlite/src/main.rs +++ b/examples/sqlite/src/main.rs @@ -1,10 +1,10 @@ mod job; use anyhow::Result; -use apalis::utils::TokioExecutor; -use apalis::{layers::tracing::TraceLayer, prelude::*, sqlite::SqliteStorage}; -use chrono::Utc; +use apalis::prelude::*; +use apalis_sql::sqlite::SqliteStorage; +use chrono::Utc; use email_service::{send_email, Email}; use job::Notification; use sqlx::SqlitePool; @@ -58,17 +58,17 @@ async fn main() -> Result<()> { produce_notifications(¬ification_storage).await?; - Monitor::::new() - .register_with_count(2, { + Monitor::new() + .register({ WorkerBuilder::new("tasty-banana") - .layer(TraceLayer::new()) - .with_storage(email_storage) + .enable_tracing() + .backend(email_storage) .build_fn(send_email) }) - .register_with_count(10, { + .register({ WorkerBuilder::new("tasty-mango") - .layer(TraceLayer::new()) - .with_storage(notification_storage) + // .enable_tracing() + .backend(notification_storage) .build_fn(job::notify) }) .run() diff --git a/examples/tracing/Cargo.toml b/examples/tracing/Cargo.toml index c7b0936..3f15f84 100644 --- a/examples/tracing/Cargo.toml +++ b/examples/tracing/Cargo.toml @@ -7,9 +7,10 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1" -apalis = { path = "../../", features = ["redis"] } +apalis = { path = "../../" } +apalis-redis = { path = "../../packages/apalis-redis" } serde = "1" -tokio = { version ="1", features = ["full"]} +tokio = { version = "1", features = ["full"] } env_logger = "0.10" tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/examples/tracing/src/main.rs b/examples/tracing/src/main.rs index 
cce99bd..3078479 100644 --- a/examples/tracing/src/main.rs +++ b/examples/tracing/src/main.rs @@ -1,17 +1,12 @@ use anyhow::Result; - +use apalis::layers::WorkerBuilderExt; +use apalis::prelude::{Monitor, Storage, WorkerBuilder, WorkerFactoryFn}; +use apalis_redis::RedisStorage; use std::error::Error; use std::fmt; use std::time::Duration; use tracing_subscriber::prelude::*; -use apalis::{ - layers::tracing::TraceLayer, - prelude::{Monitor, Storage, WorkerBuilder, WorkerFactoryFn}, - redis::RedisStorage, - utils::TokioExecutor, -}; - use tokio::time::sleep; use email_service::Email; @@ -32,7 +27,7 @@ impl Error for InvalidEmailError {} async fn email_service(email: Email) -> Result<(), InvalidEmailError> { tracing::info!("Checking if dns configured"); sleep(Duration::from_millis(1008)).await; - tracing::info!("Sent in 1 sec"); + tracing::info!("Failed in 1 sec"); Err(InvalidEmailError { email: email.to }) } @@ -63,18 +58,18 @@ async fn main() -> Result<()> { .with(fmt_layer) .init(); - let conn = apalis::redis::connect(redis_url) + let conn = apalis_redis::connect(redis_url) .await .expect("Could not connect to RedisStorage"); let storage = RedisStorage::new(conn); //This can be in another part of the program produce_jobs(storage.clone()).await?; - Monitor::::new() + Monitor::new() .register( WorkerBuilder::new("tasty-avocado") - .chain(|srv| srv.layer(TraceLayer::new())) - .with_storage(storage) + .enable_tracing() + .backend(storage) .build_fn(email_service), ) .run() diff --git a/examples/unmonitored-worker/Cargo.toml b/examples/unmonitored-worker/Cargo.toml new file mode 100644 index 0000000..403d4c6 --- /dev/null +++ b/examples/unmonitored-worker/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "unmonitored-worker" +version = "0.1.0" +edition.workspace = true +repository.workspace = true + +[dependencies] +tokio = { version = "1", features = ["full"] } +apalis = { path = "../../", features = ["limit", "catch-panic"] } +apalis-sql = { path = 
"../../packages/apalis-sql", features = [ + "sqlite", + "tokio-comp", +] } +serde = "1" +tracing-subscriber = "0.3.11" +futures = "0.3" +tower = "0.4" + + +[dependencies.tracing] +default-features = false +version = "0.1" diff --git a/examples/unmonitored-worker/src/main.rs b/examples/unmonitored-worker/src/main.rs new file mode 100644 index 0000000..7bafa85 --- /dev/null +++ b/examples/unmonitored-worker/src/main.rs @@ -0,0 +1,55 @@ +use std::time::Duration; + +use apalis::prelude::*; +use apalis_sql::sqlite::{SqlitePool, SqliteStorage}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +#[derive(Debug, Serialize, Deserialize)] +struct SelfMonitoringJob { + id: i32, +} + +async fn self_monitoring_task(task: SelfMonitoringJob, worker: Worker) { + info!("task: {:?}, {:?}", task, worker); + if task.id == 1 { + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(1)).await; + if !worker.has_pending_tasks() { + info!("done with all tasks, stopping worker"); + worker.stop(); + break; + } + } + }); + } + tokio::time::sleep(Duration::from_secs(5)).await; +} + +async fn produce_jobs(storage: &mut SqliteStorage) { + for id in 0..100 { + storage.push(SelfMonitoringJob { id }).await.unwrap(); + } +} + +#[tokio::main] +async fn main() -> Result<(), std::io::Error> { + std::env::set_var("RUST_LOG", "debug,sqlx::query=error"); + tracing_subscriber::fmt::init(); + let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); + SqliteStorage::setup(&pool) + .await + .expect("unable to run migrations for sqlite"); + let mut sqlite: SqliteStorage = SqliteStorage::new(pool); + produce_jobs(&mut sqlite).await; + + WorkerBuilder::new("tasty-banana") + .concurrency(2) + .backend(sqlite) + .build_fn(self_monitoring_task) + .on_event(|e| info!("{e}")) + .run() + .await; + Ok(()) +} diff --git a/packages/apalis-core/Cargo.toml b/packages/apalis-core/Cargo.toml index e3247ea..475c4b9 100644 --- a/packages/apalis-core/Cargo.toml +++ 
b/packages/apalis-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "apalis-core" -version = "0.5.5" +version = "0.6.0" authors = ["Njuguna Mureithi "] edition.workspace = true repository.workspace = true @@ -17,7 +17,6 @@ serde = { version = "1.0", features = ["derive"] } futures = { version = "0.3.30", features = ["async-await"] } tower = { version = "0.5", features = ["util"], default-features = false } pin-project-lite = "0.2.14" -async-oneshot = "0.5.9" thiserror = "2.0.0" ulid = { version = "1.1.2", default-features = false, features = ["std"] } futures-timer = { version = "3.0.3", optional = true } @@ -30,10 +29,11 @@ optional = true [features] -default = [] +default = ["test-utils"] docsrs = ["document-features"] sleep = ["futures-timer"] json = ["serde_json"] +test-utils = [] [package.metadata.docs.rs] # defines the configuration attribute `docsrs` diff --git a/packages/apalis-core/src/backend.rs b/packages/apalis-core/src/backend.rs new file mode 100644 index 0000000..eae050e --- /dev/null +++ b/packages/apalis-core/src/backend.rs @@ -0,0 +1,92 @@ +use std::{any::type_name, future::Future}; + +use futures::Stream; +use serde::{Deserialize, Serialize}; +use tower::Service; + +use crate::{ + poller::Poller, + request::State, + worker::{Context, Worker}, +}; + +/// A backend represents a task source +/// Both [`Storage`] and [`MessageQueue`] need to implement it for workers to be able to consume tasks +/// +/// [`Storage`]: crate::storage::Storage +/// [`MessageQueue`]: crate::mq::MessageQueue +pub trait Backend { + /// The stream to be produced by the backend + type Stream: Stream, crate::error::Error>>; + + /// Returns the final decoration of layers + type Layer; + + /// Returns a poller that is ready for streaming + fn poll>( + self, + worker: &Worker, + ) -> Poller; +} + +/// Represents functionality that allows reading of jobs and stats from a backend +/// Some backends esp MessageQueues may not currently implement this +pub trait BackendExpose +where + 
Self: Sized, +{ + /// The request type being handled by the backend + type Request; + /// The error returned during reading jobs and stats + type Error; + /// List all Workers that are working on a backend + fn list_workers( + &self, + ) -> impl Future>, Self::Error>> + Send; + + /// Returns the counts of jobs in different states + fn stats(&self) -> impl Future> + Send; + + /// Fetch jobs persisted in a backend + fn list_jobs( + &self, + status: &State, + page: i32, + ) -> impl Future, Self::Error>> + Send; +} + +/// Represents the current statistics of a backend +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct Stat { + /// Represents pending tasks + pub pending: usize, + /// Represents running tasks + pub running: usize, + /// Represents dead tasks + pub dead: usize, + /// Represents failed tasks + pub failed: usize, + /// Represents successful tasks + pub success: usize, +} + +/// A serializable version of a worker's state. +#[derive(Debug, Serialize, Deserialize)] +pub struct WorkerState { + /// Type of task being consumed by the worker, useful for display and filtering + pub r#type: String, + /// The type of job stream + pub source: String, + // TODO: // The layers that were loaded for worker. + // TODO: // pub layers: Vec, + // TODO: // last_seen: Timestamp, +} +impl WorkerState { + /// Build a new state + pub fn new(r#type: String) -> Self { + Self { + r#type, + source: type_name::().to_string(), + } + } +} diff --git a/packages/apalis-core/src/builder.rs b/packages/apalis-core/src/builder.rs index 3bd05b1..b362509 100644 --- a/packages/apalis-core/src/builder.rs +++ b/packages/apalis-core/src/builder.rs @@ -7,29 +7,27 @@ use tower::{ }; use crate::{ + backend::Backend, error::Error, layers::extensions::Data, - mq::MessageQueue, request::Request, service_fn::service_fn, service_fn::ServiceFn, - storage::Storage, worker::{Ready, Worker, WorkerId}, - Backend, }; /// Allows building a [`Worker`]. 
/// Usually the output is [`Worker`] -pub struct WorkerBuilder { +pub struct WorkerBuilder { id: WorkerId, - request: PhantomData, + request: PhantomData>, layer: ServiceBuilder, source: Source, service: PhantomData, } -impl std::fmt::Debug - for WorkerBuilder +impl std::fmt::Debug + for WorkerBuilder { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("WorkerBuilder") @@ -41,10 +39,10 @@ impl std::fmt::Debug } } -impl WorkerBuilder<(), (), Identity, Serv> { +impl WorkerBuilder<(), (), (), Identity, Serv> { /// Build a new [`WorkerBuilder`] instance with a name for the worker to build - pub fn new>(name: T) -> WorkerBuilder<(), (), Identity, Serv> { - let job: PhantomData<()> = PhantomData; + pub fn new>(name: T) -> WorkerBuilder<(), (), (), Identity, Serv> { + let job: PhantomData> = PhantomData; WorkerBuilder { request: job, layer: ServiceBuilder::new(), @@ -55,12 +53,17 @@ impl WorkerBuilder<(), (), Identity, Serv> { } } -impl WorkerBuilder { +impl WorkerBuilder<(), (), (), M, Serv> { /// Consume a stream directly - pub fn stream>, Error>> + Send + 'static, NJ>( + #[deprecated(since = "0.6.0", note = "Consider using the `.backend`")] + pub fn stream< + NS: Stream>, Error>> + Send + 'static, + NJ, + Ctx, + >( self, stream: NS, - ) -> WorkerBuilder { + ) -> WorkerBuilder { WorkerBuilder { request: PhantomData, layer: self.layer, @@ -70,39 +73,14 @@ impl WorkerBuilder { } } - /// Set the source to a [Storage] - pub fn with_storage, NJ>( + /// Set the source to a backend that implements [Backend] + pub fn backend, Res>, NJ, Res: Send, Ctx>( self, - storage: NS, - ) -> WorkerBuilder { - WorkerBuilder { - request: PhantomData, - layer: self.layer, - source: storage, - id: self.id, - service: self.service, - } - } - - /// Set the source to a [MessageQueue] - pub fn with_mq, NJ>( - self, - message_queue: NS, - ) -> WorkerBuilder { - WorkerBuilder { - request: PhantomData, - layer: self.layer, - source: message_queue, - id: self.id, - 
service: self.service, - } - } - - /// Set the source to a generic backend that implements only [Backend] - pub fn source>, NJ>( - self, - backend: NS, - ) -> WorkerBuilder { + backend: NB, + ) -> WorkerBuilder + where + Serv: Service, Response = Res>, + { WorkerBuilder { request: PhantomData, layer: self.layer, @@ -113,13 +91,13 @@ impl WorkerBuilder { } } -impl WorkerBuilder { +impl WorkerBuilder { /// Allows of decorating the service that consumes jobs. /// Allows adding multiple [`tower`] middleware pub fn chain( self, - f: impl Fn(ServiceBuilder) -> ServiceBuilder, - ) -> WorkerBuilder { + f: impl FnOnce(ServiceBuilder) -> ServiceBuilder, + ) -> WorkerBuilder { let middleware = f(self.layer); WorkerBuilder { @@ -131,7 +109,7 @@ impl WorkerBuilder { } } /// Allows adding a single layer [tower] middleware - pub fn layer(self, layer: U) -> WorkerBuilder, Serv> + pub fn layer(self, layer: U) -> WorkerBuilder, Serv> where M: Layer, { @@ -146,7 +124,7 @@ impl WorkerBuilder { /// Adds data to the context /// This will be shared by all requests - pub fn data(self, data: D) -> WorkerBuilder, M>, Serv> + pub fn data(self, data: D) -> WorkerBuilder, M>, Serv> where M: Layer>, { @@ -160,33 +138,32 @@ impl WorkerBuilder { } } -impl> + 'static, M: 'static, S> - WorkerFactory for WorkerBuilder +impl WorkerFactory for WorkerBuilder where - S: Service> + Send + 'static + Clone + Sync, + S: Service> + Send + 'static + Sync, S::Future: Send, S::Response: 'static, - P::Layer: Layer, - M: Layer<>::Service>, + M: Layer, + Req: Send + 'static + Sync, + P: Backend, S::Response> + 'static, + M: 'static, { type Source = P; type Service = M::Service; - /// Build a worker, given a tower service - fn build(self, service: S) -> Worker> { + + fn build(self, service: S) -> Worker> { let worker_id = self.id; - let common_layer = self.source.common_layer(worker_id.clone()); let poller = self.source; - let middleware = self.layer.layer(common_layer); + let middleware = self.layer; let service = 
middleware.service(service); Worker::new(worker_id, Ready::new(service, poller)) } } - /// Helper trait for building new Workers from [`WorkerBuilder`] -pub trait WorkerFactory { +pub trait WorkerFactory { /// The request source for the worker type Source; @@ -205,7 +182,7 @@ pub trait WorkerFactory { /// Helper trait for building new Workers from [`WorkerBuilder`] -pub trait WorkerFactoryFn { +pub trait WorkerFactoryFn { /// The request source for the [`Worker`] type Source; @@ -224,36 +201,29 @@ pub trait WorkerFactoryFn { /// - An async function with an argument of the item being processed plus up-to 16 arguments that are extracted from the request [`Data`] /// /// A function can return: - /// - Unit + /// - () /// - primitive /// - Result /// - impl IntoResponse /// /// ```rust + /// # use apalis_core::layers::extensions::Data; /// #[derive(Debug)] /// struct Email; /// #[derive(Debug)] /// struct PgPool; - /// # struct PgError; - /// - /// async fn send_email(email: Email) { - /// // Implementation of the job function - /// // ... - /// } /// - /// async fn send_email(email: Email, data: Data) -> Result<(), PgError> { - /// // Implementation of the job function? - /// // ... - /// Ok(()) + /// async fn send_email(email: Email, data: Data) { + /// // Implementation of the task function? 
/// } /// ``` /// fn build_fn(self, f: F) -> Worker>; } -impl WorkerFactoryFn for W +impl WorkerFactoryFn for W where - W: WorkerFactory>, + W: WorkerFactory>, { type Source = W::Source; diff --git a/packages/apalis-core/src/codec/json.rs b/packages/apalis-core/src/codec/json.rs index 7ed7c7f..35376ad 100644 --- a/packages/apalis-core/src/codec/json.rs +++ b/packages/apalis-core/src/codec/json.rs @@ -1,40 +1,56 @@ -use crate::{error::Error, Codec}; -use serde::{de::DeserializeOwned, Serialize}; +use std::marker::PhantomData; + +use crate::codec::Codec; +use serde::{Deserialize, Serialize}; use serde_json::Value; /// Json encoding and decoding #[derive(Debug, Clone, Default)] -pub struct JsonCodec; +pub struct JsonCodec { + _o: PhantomData, +} -impl Codec> for JsonCodec { - type Error = Error; - fn encode(&self, input: &T) -> Result, Self::Error> { - serde_json::to_vec(input).map_err(|e| Error::SourceError(Box::new(e))) +impl Codec for JsonCodec> { + type Compact = Vec; + type Error = serde_json::Error; + fn encode(input: T) -> Result, Self::Error> { + serde_json::to_vec(&input) } - fn decode(&self, compact: &Vec) -> Result { - serde_json::from_slice(compact).map_err(|e| Error::SourceError(Box::new(e))) + fn decode(compact: Vec) -> Result + where + O: for<'de> Deserialize<'de>, + { + serde_json::from_slice(&compact) } } -impl Codec for JsonCodec { - type Error = Error; - fn encode(&self, input: &T) -> Result { - serde_json::to_string(input).map_err(|e| Error::SourceError(Box::new(e))) +impl Codec for JsonCodec { + type Compact = String; + type Error = serde_json::Error; + fn encode(input: T) -> Result { + serde_json::to_string(&input) } - fn decode(&self, compact: &String) -> Result { - serde_json::from_str(compact).map_err(|e| Error::SourceError(Box::new(e))) + fn decode(compact: String) -> Result + where + O: for<'de> Deserialize<'de>, + { + serde_json::from_str(&compact) } } -impl Codec for JsonCodec { - type Error = Error; - fn encode(&self, input: &T) -> Result 
{ - serde_json::to_value(input).map_err(|e| Error::SourceError(Box::new(e))) +impl Codec for JsonCodec { + type Compact = Value; + type Error = serde_json::Error; + fn encode(input: T) -> Result { + serde_json::to_value(input) } - fn decode(&self, compact: &Value) -> Result { - serde_json::from_value(compact.clone()).map_err(|e| Error::SourceError(Box::new(e))) + fn decode(compact: Value) -> Result + where + O: for<'de> Deserialize<'de>, + { + serde_json::from_value(compact) } } diff --git a/packages/apalis-core/src/codec/message_pack.rs b/packages/apalis-core/src/codec/message_pack.rs deleted file mode 100644 index e69de29..0000000 diff --git a/packages/apalis-core/src/codec/mod.rs b/packages/apalis-core/src/codec/mod.rs index a2d1f55..c12d6fa 100644 --- a/packages/apalis-core/src/codec/mod.rs +++ b/packages/apalis-core/src/codec/mod.rs @@ -1,3 +1,23 @@ +use serde::{Deserialize, Serialize}; + +use crate::error::BoxDynError; + +/// A codec allows backends to encode and decode data +pub trait Codec { + /// The mode of storage by the codec + type Compact; + /// Error encountered by the codec + type Error: Into; + /// The encoding method + fn encode(input: I) -> Result + where + I: Serialize; + /// The decoding method + fn decode(input: Self::Compact) -> Result + where + O: for<'de> Deserialize<'de>; +} + /// Encoding for tasks using json #[cfg(feature = "json")] pub mod json; diff --git a/packages/apalis-core/src/data.rs b/packages/apalis-core/src/data.rs index 33cd3f9..e2829c8 100644 --- a/packages/apalis-core/src/data.rs +++ b/packages/apalis-core/src/data.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; use std::fmt; use std::hash::{BuildHasherDefault, Hasher}; +use crate::error::Error; + type AnyMap = HashMap, BuildHasherDefault>; // With TypeIds as keys, there's no need to hash them. 
They are already hashes @@ -87,6 +89,27 @@ impl Extensions { .and_then(|boxed| (**boxed).as_any().downcast_ref()) } + /// Get a checked reference to a type previously inserted on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use apalis_core::data::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.get_checked::().is_err()); + /// ext.insert(5i32); + /// + /// assert_eq!(ext.get_checked::(), Ok(&5i32)); + /// ``` + pub fn get_checked(&self) -> Result<&T, Error> { + self.get() + .ok_or({ + let type_name = std::any::type_name::(); + Error::MissingData( + format!("Missing the an entry for `{type_name}`. Did you forget to add `.data(<{type_name}>)", )) + }) + } + /// Get a mutable reference to a type previously inserted on this `Extensions`. /// /// # Example diff --git a/packages/apalis-core/src/error.rs b/packages/apalis-core/src/error.rs index 7ba3cf5..aa27412 100644 --- a/packages/apalis-core/src/error.rs +++ b/packages/apalis-core/src/error.rs @@ -1,5 +1,13 @@ -use std::error::Error as StdError; +use std::{ + error::Error as StdError, + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use thiserror::Error; +use tower::Service; use crate::worker::WorkerError; @@ -7,38 +15,117 @@ use crate::worker::WorkerError; pub type BoxDynError = Box; /// Represents a general error returned by a task or by internals of the platform -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone)] #[non_exhaustive] pub enum Error { /// An error occurred during execution. - #[error("Task Failed: {0}")] - Failed(#[source] BoxDynError), - - /// A generic IO error - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - - /// Missing some context and yet it was requested during execution. 
- #[error("MissingContext: {0}")] - InvalidContext(String), + #[error("FailedError: {0}")] + Failed(#[source] Arc), /// Execution was aborted - #[error("Execution was aborted")] - Abort, + #[error("AbortError: {0}")] + Abort(#[source] Arc), + #[doc(hidden)] /// Encountered an error during worker execution - #[error("Encountered an error during worker execution")] + /// This should not be used inside a task function + #[error("WorkerError: {0}")] WorkerError(WorkerError), + /// Missing some data and yet it was requested during execution. + /// This should not be used inside a task function + #[error("MissingDataError: {0}")] + MissingData(String), + #[doc(hidden)] /// Encountered an error during service execution /// This should not be used inside a task function #[error("Encountered an error during service execution")] - ServiceError(#[source] BoxDynError), + ServiceError(#[source] Arc), #[doc(hidden)] /// Encountered an error during service execution /// This should not be used inside a task function #[error("Encountered an error during streaming")] - SourceError(#[source] BoxDynError), + SourceError(#[source] Arc), +} + +impl From for Error { + fn from(err: BoxDynError) -> Self { + if let Some(e) = err.downcast_ref::() { + e.clone() + } else { + Error::Failed(Arc::new(err)) + } + } +} + +/// A Tower layer for handling and converting service errors into a custom `Error` type. +/// +/// This layer wraps a service and intercepts any errors returned by the service. +/// It attempts to downcast the error into the custom `Error` enum. If the downcast +/// succeeds, it returns the downcasted `Error`. If the downcast fails, the original +/// error is wrapped in `Error::Failed`. +/// +/// The service's error type must implement `Into`, allowing for flexible +/// error handling, especially when dealing with trait objects or complex error chains. 
+#[derive(Clone, Debug)] +pub struct ErrorHandlingLayer { + _p: PhantomData<()>, +} + +impl ErrorHandlingLayer { + /// Create a new ErrorHandlingLayer + pub fn new() -> Self { + Self { _p: PhantomData } + } +} + +impl Default for ErrorHandlingLayer { + fn default() -> Self { + Self::new() + } +} + +impl tower::layer::Layer for ErrorHandlingLayer { + type Service = ErrorHandlingService; + + fn layer(&self, service: S) -> Self::Service { + ErrorHandlingService { service } + } +} + +/// The underlying service +#[derive(Clone, Debug)] +pub struct ErrorHandlingService { + service: S, +} + +impl Service for ErrorHandlingService +where + S: Service, + S::Error: Into, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(|e| { + let boxed_error: BoxDynError = e.into(); + boxed_error.into() + }) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.service.call(req); + + Box::pin(async move { + fut.await.map_err(|e| { + let boxed_error: BoxDynError = e.into(); + boxed_error.into() + }) + }) + } } diff --git a/packages/apalis-core/src/executor.rs b/packages/apalis-core/src/executor.rs deleted file mode 100644 index 5891964..0000000 --- a/packages/apalis-core/src/executor.rs +++ /dev/null @@ -1,7 +0,0 @@ -use futures::Future; - -/// An Executor that is used to spawn futures -pub trait Executor { - /// Spawns a new asynchronous task - fn spawn(&self, future: impl Future + Send + 'static); -} diff --git a/packages/apalis-core/src/layers.rs b/packages/apalis-core/src/layers.rs index 59108ed..2d9c971 100644 --- a/packages/apalis-core/src/layers.rs +++ b/packages/apalis-core/src/layers.rs @@ -1,10 +1,15 @@ +use crate::error::{BoxDynError, Error}; +use crate::request::Request; +use crate::response::Response; +use futures::channel::mpsc::{SendError, Sender}; +use futures::SinkExt; +use 
futures::{future::BoxFuture, Future, FutureExt}; +use serde::Serialize; use std::marker::PhantomData; use std::{fmt, sync::Arc}; -pub use tower::{layer::layer_fn, util::BoxCloneService, Layer, Service, ServiceBuilder}; - -use futures::{future::BoxFuture, Future, FutureExt}; - -use crate::{request::Request, worker::WorkerId}; +pub use tower::{ + layer::layer_fn, layer::util::Identity, util::BoxCloneService, Layer, Service, ServiceBuilder, +}; /// A generic layer that has been stripped off types. /// This is returned by a [crate::Backend] and can be used to customize the middleware of the service consuming tasks @@ -88,7 +93,7 @@ pub mod extensions { /// /// let worker = WorkerBuilder::new("tasty-avocado") /// .data(state) - /// .source(MemoryStorage::new()) + /// .backend(MemoryStorage::new()) /// .build(service_fn(email_service)); /// ``` @@ -129,9 +134,9 @@ pub mod extensions { value: T, } - impl Service> for AddExtension + impl Service> for AddExtension where - S: Service>, + S: Service>, T: Clone + Send + Sync + 'static, { type Response = S::Response; @@ -143,126 +148,160 @@ pub mod extensions { self.inner.poll_ready(cx) } - fn call(&mut self, mut req: Request) -> Self::Future { - req.data.insert(self.value.clone()); + fn call(&mut self, mut req: Request) -> Self::Future { + req.parts.data.insert(self.value.clone()); self.inner.call(req) } } } /// A trait for acknowledging successful processing -pub trait Ack { +/// This trait is called even when a task fails. 
+/// This is a way of a [`Backend`] to save the result of a job or message +pub trait Ack { /// The data to fetch from context to allow acknowledgement - type Acknowledger; + type Context; /// The error returned by the ack - type Error: std::error::Error; + type AckError: std::error::Error; + /// Acknowledges successful processing of the given request fn ack( - &self, - worker_id: &WorkerId, - data: &Self::Acknowledger, - ) -> impl Future> + Send; + &mut self, + ctx: &Self::Context, + response: &Response, + ) -> impl Future> + Send; +} + +impl Ack + for Sender<(Ctx, Response)> +{ + type AckError = SendError; + type Context = Ctx; + async fn ack( + &mut self, + ctx: &Self::Context, + result: &Response, + ) -> Result<(), Self::AckError> { + let ctx = ctx.clone(); + self.send((ctx, result.clone())).await.unwrap(); + Ok(()) + } } /// A layer that acknowledges a job completed successfully #[derive(Debug)] -pub struct AckLayer, J> { +pub struct AckLayer { ack: A, - job_type: PhantomData, - worker_id: WorkerId, + job_type: PhantomData>, + res: PhantomData, } -impl, J> AckLayer { +impl AckLayer { /// Build a new [AckLayer] for a job - pub fn new(ack: A, worker_id: WorkerId) -> Self { + pub fn new(ack: A) -> Self { Self { ack, job_type: PhantomData, - worker_id, + res: PhantomData, } } } -impl Layer for AckLayer +impl Layer for AckLayer where - S: Service> + Send + 'static, + S: Service> + Send + 'static, S::Error: std::error::Error + Send + Sync + 'static, S::Future: Send + 'static, - A: Ack + Clone + Send + Sync + 'static, + A: Ack + Clone + Send + Sync + 'static, { - type Service = AckService; + type Service = AckService; fn layer(&self, service: S) -> Self::Service { AckService { service, ack: self.ack.clone(), job_type: PhantomData, - worker_id: self.worker_id.clone(), + res: PhantomData, } } } /// The underlying service for an [AckLayer] #[derive(Debug)] -pub struct AckService { +pub struct AckService { service: SV, ack: A, - job_type: PhantomData, - worker_id: 
WorkerId, + job_type: PhantomData>, + res: PhantomData, } -impl Clone for AckService { +impl Clone for AckService { fn clone(&self) -> Self { Self { ack: self.ack.clone(), job_type: PhantomData, - worker_id: self.worker_id.clone(), service: self.service.clone(), + res: PhantomData, } } } -impl Service> for AckService +impl Service> for AckService where - SV: Service> + Send + Sync + 'static, - SV::Error: std::error::Error + Send + Sync + 'static, - >>::Future: std::marker::Send + 'static, - A: Ack + Send + 'static + Clone + Send + Sync, - J: 'static, - >>::Response: std::marker::Send, - >::Acknowledger: Sync + Send + Clone, + SV: Service> + Send + Sync + 'static, + >>::Error: Into + Send + Sync + 'static, + >>::Future: std::marker::Send + 'static, + A: Ack>>::Response, Context = Ctx> + + Send + + 'static + + Clone + + Send + + Sync, + Req: 'static + Send, + >>::Response: std::marker::Send + fmt::Debug + Sync + Serialize, + >::Context: Sync + Send + Clone, + >>::Response>>::Context: 'static, + Ctx: Clone, { type Response = SV::Response; - type Error = SV::Error; + type Error = Error; type Future = BoxFuture<'static, Result>; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.service.poll_ready(cx) + self.service + .poll_ready(cx) + .map_err(|e| Error::Failed(Arc::new(e.into()))) } - fn call(&mut self, request: Request) -> Self::Future { - let ack = self.ack.clone(); - let worker_id = self.worker_id.clone(); - let data = request.get::<>::Acknowledger>().cloned(); - + fn call(&mut self, request: Request) -> Self::Future { + let mut ack = self.ack.clone(); + let ctx = request.parts.context.clone(); + let attempt = request.parts.attempt.clone(); + let task_id = request.parts.task_id.clone(); let fut = self.service.call(request); let fut_with_ack = async move { - let res = fut.await; - if let Some(data) = data { - if let Err(_e) = ack.ack(&worker_id, &data).await { - // tracing::warn!("Acknowledgement Failed: {}", e); - // try get 
monitor, and emit + let res = fut.await.map_err(|err| { + let e: BoxDynError = err.into(); + // Try to downcast the error to see if it is already of type `Error` + if let Some(custom_error) = e.downcast_ref::() { + return custom_error.clone(); } - } else { - // tracing::warn!( - // "Acknowledgement could not be called due to missing ack data in context : {}", - // &std::any::type_name::<>::Acknowledger>() - // ); + Error::Failed(Arc::new(e)) + }); + let response = Response { + attempt, + inner: res, + task_id, + _priv: (), + }; + if let Err(_e) = ack.ack(&ctx, &response).await { + // TODO: Implement tracing in apalis core + // tracing::error!("Acknowledgement Failed: {}", e); } - res + response.inner }; fut_with_ack.boxed() } diff --git a/packages/apalis-core/src/lib.rs b/packages/apalis-core/src/lib.rs index afbcdb9..6e356d4 100644 --- a/packages/apalis-core/src/lib.rs +++ b/packages/apalis-core/src/lib.rs @@ -22,16 +22,13 @@ #![cfg_attr(docsrs, feature(doc_cfg))] //! # apalis-core //! Utilities for building job and message processing tools. -use futures::Stream; -use poller::Poller; -use worker::WorkerId; - /// Represent utilities for creating worker instances. pub mod builder; + +/// Represents a task source eg Postgres or Redis +pub mod backend; /// Includes all possible error types. pub mod error; -/// Represents an executor. 
-pub mod executor; /// Represents middleware offered through [`tower`] pub mod layers; /// Represents monitoring of running workers @@ -65,53 +62,297 @@ pub mod task; /// Codec for handling data pub mod codec; -/// A backend represents a task source -/// Both [`Storage`] and [`MessageQueue`] need to implement it for workers to be able to consume tasks -/// -/// [`Storage`]: crate::storage::Storage -/// [`MessageQueue`]: crate::mq::MessageQueue -pub trait Backend { - /// The stream to be produced by the backend - type Stream: Stream, crate::error::Error>>; +/// Sleep utilities +#[cfg(feature = "sleep")] +pub async fn sleep(duration: std::time::Duration) { + futures_timer::Delay::new(duration).await; +} - /// Returns the final decoration of layers - type Layer; +#[cfg(feature = "sleep")] +/// Interval utilities +pub mod interval { + use std::fmt; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::time::Duration; - /// Allows the backend to decorate the service with [Layer] - /// - /// [Layer]: tower::Layer - #[allow(unused)] - fn common_layer(&self, worker: WorkerId) -> Self::Layer; + use futures::future::BoxFuture; + use futures::Stream; - /// Returns a poller that is ready for streaming - fn poll(self, worker: WorkerId) -> Poller; -} + use crate::sleep; + /// Creates a new stream that yields at a set interval. 
+ pub fn interval(duration: Duration) -> Interval { + Interval { + timer: Box::pin(sleep(duration)), + interval: duration, + } + } -/// This allows encoding and decoding of requests in different backends -pub trait Codec { - /// Error encountered by the codec - type Error; + /// A stream representing notifications at fixed interval + #[must_use = "streams do nothing unless polled or .awaited"] + pub struct Interval { + timer: BoxFuture<'static, ()>, + interval: Duration, + } - /// Convert to the compact version - fn encode(&self, input: &T) -> Result; + impl fmt::Debug for Interval { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Interval") + .field("interval", &self.interval) + .field("timer", &"a future represented `apalis_core::sleep`") + .finish() + } + } - /// Decode back to our request type - fn decode(&self, compact: &Compact) -> Result; -} + impl Stream for Interval { + type Item = (); -/// Sleep utilities -#[cfg(feature = "sleep")] -pub async fn sleep(duration: std::time::Duration) { - futures_timer::Delay::new(duration).await; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.timer).poll(cx) { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + }; + let interval = self.interval; + let fut = std::mem::replace(&mut self.timer, Box::pin(sleep(interval))); + drop(fut); + Poll::Ready(Some(())) + } + } } -#[cfg(test)] -#[doc(hidden)] -#[derive(Debug, Default, Clone)] -pub(crate) struct TestExecutor; -#[cfg(test)] -impl crate::executor::Executor for TestExecutor { - fn spawn(&self, future: impl futures::prelude::Future + Send + 'static) { - tokio::spawn(future); +#[cfg(feature = "test-utils")] +/// Test utilities that allows you to test backends +pub mod test_utils { + use crate::backend::Backend; + use crate::error::BoxDynError; + use crate::request::Request; + use crate::task::task_id::TaskId; + use crate::worker::{Worker, WorkerId}; + use 
futures::channel::mpsc::{channel, Receiver, Sender}; + use futures::future::BoxFuture; + use futures::stream::{Stream, StreamExt}; + use futures::{Future, FutureExt, SinkExt}; + use std::fmt::Debug; + use std::marker::PhantomData; + use std::ops::{Deref, DerefMut}; + use std::pin::Pin; + use std::task::{Context, Poll}; + use tower::{Layer, Service}; + + /// Define a dummy service + #[derive(Debug, Clone)] + pub struct DummyService; + + impl Service for DummyService { + type Response = Request; + type Error = std::convert::Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = async move { Ok(req) }; + Box::pin(fut) + } + } + + /// A generic backend wrapper that polls and executes jobs + #[derive(Debug)] + pub struct TestWrapper { + stop_tx: Sender<()>, + res_rx: Receiver<(TaskId, Result)>, + _p: PhantomData, + _r: PhantomData, + backend: B, + } + /// A test wrapper to allow you to test without requiring a worker. 
+ /// Important for testing backends and jobs + /// # Example + /// ```no_run + /// #[cfg(tests)] + /// mod tests { + /// use crate::{ + /// error::Error, memory::MemoryStorage, mq::MessageQueue, service_fn::service_fn, + /// }; + /// + /// use super::*; + /// + /// async fn is_even(req: usize) -> Result<(), Error> { + /// if req % 2 == 0 { + /// Ok(()) + /// } else { + /// Err(Error::Abort("Not an even number".to_string())) + /// } + /// } + /// + /// #[tokio::test] + /// async fn test_accepts_even() { + /// let backend = MemoryStorage::new(); + /// let (mut tester, poller) = TestWrapper::new_with_service(backend, service_fn(is_even)); + /// tokio::spawn(poller); + /// tester.enqueue(42usize).await.unwrap(); + /// assert_eq!(tester.size().await.unwrap(), 1); + /// let (_, resp) = tester.execute_next().await; + /// assert_eq!(resp, Ok("()".to_string())); + /// } + ///} + /// ```` + impl TestWrapper, Res> + where + B: Backend, Res> + Send + Sync + 'static + Clone, + Req: Send + 'static, + Ctx: Send, + B::Stream: Send + 'static, + B::Stream: Stream>, crate::error::Error>> + Unpin, + { + /// Build a new instance provided a custom service + pub fn new_with_service(backend: B, service: S) -> (Self, BoxFuture<'static, ()>) + where + S: Service, Response = Res> + Send + 'static, + B::Layer: Layer, + <, Res>>::Layer as Layer>::Service: + Service> + Send + 'static, + <<, Res>>::Layer as Layer>::Service as Service< + Request, + >>::Response: Send + Debug, + <<, Res>>::Layer as Layer>::Service as Service< + Request, + >>::Error: Send + Into + Sync, + <<, Res>>::Layer as Layer>::Service as Service< + Request, + >>::Future: Send + 'static, + { + let worker_id = WorkerId::new("test-worker"); + let worker = Worker::new(worker_id, crate::worker::Context::default()); + let b = backend.clone(); + let mut poller = b.poll::(&worker); + let (stop_tx, mut stop_rx) = channel::<()>(1); + + let (mut res_tx, res_rx) = channel(10); + + let mut service = poller.layer.layer(service); + + let 
poller = async move { + let heartbeat = poller.heartbeat.shared(); + loop { + futures::select! { + + item = poller.stream.next().fuse() => match item { + Some(Ok(Some(req))) => { + let task_id = req.parts.task_id.clone(); + match service.call(req).await { + Ok(res) => { + res_tx.send((task_id, Ok(format!("{res:?}")))).await.unwrap(); + }, + Err(err) => { + res_tx.send((task_id, Err(err.into().to_string()))).await.unwrap(); + } + } + } + Some(Ok(None)) | None => break, + Some(Err(_e)) => { + // handle error + break; + } + }, + _ = stop_rx.next().fuse() => break, + _ = heartbeat.clone().fuse() => { + + }, + } + } + }; + ( + TestWrapper { + stop_tx, + res_rx, + _p: PhantomData, + backend, + _r: PhantomData, + }, + poller.boxed(), + ) + } + + /// Stop polling + pub fn stop(mut self) { + self.stop_tx.try_send(()).unwrap(); + } + + /// Gets the current state of results + pub async fn execute_next(&mut self) -> (TaskId, Result) { + self.res_rx.next().await.unwrap() + } + } + + impl Deref for TestWrapper, Res> + where + B: Backend, Res>, + { + type Target = B; + + fn deref(&self) -> &Self::Target { + &self.backend + } + } + + impl DerefMut for TestWrapper, Res> + where + B: Backend, Res>, + { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.backend + } + } + + pub use tower::service_fn as apalis_test_service_fn; + + #[macro_export] + /// Tests a generic mq + macro_rules! test_message_queue { + ($backend_instance:expr) => { + #[tokio::test] + async fn it_works_as_an_mq_backend() { + let backend = $backend_instance; + let service = apalis_test_service_fn(|request: Request| async { + Ok::<_, io::Error>(request) + }); + let (mut t, poller) = TestWrapper::new_with_service(backend, service); + tokio::spawn(poller); + t.enqueue(1).await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + let _res = t.execute_next().await; + // assert_eq!(res.len(), 1); // One job is done + } + }; + } + #[macro_export] + /// Tests a generic storage + macro_rules! 
generic_storage_test { + ($setup:path ) => { + #[tokio::test] + async fn integration_test_storage_push_and_consume() { + let backend = $setup().await; + let service = apalis_test_service_fn(|request: Request| async move { + Ok::<_, io::Error>(request.args) + }); + let (mut t, poller) = TestWrapper::new_with_service(backend, service); + tokio::spawn(poller); + let res = t.len().await.unwrap(); + assert_eq!(res, 0); // No jobs + t.push(1).await.unwrap(); + let res = t.len().await.unwrap(); + assert_eq!(res, 1); // A job exists + let res = t.execute_next().await; + assert_eq!(res.1, Ok("1".to_owned())); + // TODO: all storages need to satisfy this rule, redis does not + // let res = t.len().await.unwrap(); + // assert_eq!(res, 0); + t.vacuum().await.unwrap(); + } + }; } } diff --git a/packages/apalis-core/src/memory.rs b/packages/apalis-core/src/memory.rs index dd832a5..7caeb1c 100644 --- a/packages/apalis-core/src/memory.rs +++ b/packages/apalis-core/src/memory.rs @@ -1,9 +1,10 @@ use crate::{ + backend::Backend, mq::MessageQueue, + poller::Poller, poller::{controller::Controller, stream::BackendStream}, request::{Request, RequestStream}, - worker::WorkerId, - Backend, Poller, + worker::{self, Worker}, }; use futures::{ channel::mpsc::{channel, Receiver, Sender}, @@ -14,7 +15,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower::{layer::util::Identity, ServiceBuilder}; +use tower::layer::util::Identity; #[derive(Debug)] /// An example of the basics of a backend @@ -52,8 +53,8 @@ impl Clone for MemoryStorage { /// In-memory queue that implements [Stream] #[derive(Debug)] pub struct MemoryWrapper { - sender: Sender, - receiver: Arc>>, + sender: Sender>, + receiver: Arc>>>, } impl Clone for MemoryWrapper { @@ -84,7 +85,7 @@ impl Default for MemoryWrapper { } impl Stream for MemoryWrapper { - type Item = T; + type Item = Request; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if let Some(mut receiver) = self.receiver.try_lock() { @@ -96,37 
+97,44 @@ impl Stream for MemoryWrapper { } // MemoryStorage as a Backend -impl Backend> for MemoryStorage { - type Stream = BackendStream>>; +impl Backend, Res> for MemoryStorage { + type Stream = BackendStream>>; - type Layer = ServiceBuilder; + type Layer = Identity; - fn common_layer(&self, _worker: WorkerId) -> Self::Layer { - ServiceBuilder::new() - } - - fn poll(self, _worker: WorkerId) -> Poller { - let stream = self.inner.map(|r| Ok(Some(Request::new(r)))).boxed(); + fn poll(self, _worker: &Worker) -> Poller { + let stream = self.inner.map(|r| Ok(Some(r))).boxed(); Poller { stream: BackendStream::new(stream, self.controller), - heartbeat: Box::pin(async {}), + heartbeat: Box::pin(futures::future::pending()), + layer: Identity::new(), + _priv: (), } } } impl MessageQueue for MemoryStorage { type Error = (); - async fn enqueue(&self, message: Message) -> Result<(), Self::Error> { - self.inner.sender.clone().try_send(message).unwrap(); + async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> { + self.inner + .sender + .try_send(Request::new(message)) + .map_err(|_| ())?; Ok(()) } - async fn dequeue(&self) -> Result, ()> { - Err(()) - // self.inner.receiver.lock().await.next().await + async fn dequeue(&mut self) -> Result, ()> { + Ok(self + .inner + .receiver + .lock() + .await + .next() + .await + .map(|r| r.args)) } - async fn size(&self) -> Result { - Ok(self.inner.clone().count().await) + async fn size(&mut self) -> Result { + Ok(self.inner.receiver.lock().await.size_hint().0) } } diff --git a/packages/apalis-core/src/monitor/mod.rs b/packages/apalis-core/src/monitor/mod.rs index 393f58f..895337f 100644 --- a/packages/apalis-core/src/monitor/mod.rs +++ b/packages/apalis-core/src/monitor/mod.rs @@ -1,99 +1,67 @@ use std::{ - any::Any, fmt::{self, Debug, Formatter}, - sync::{Arc, RwLock}, + sync::Arc, }; use futures::{future::BoxFuture, Future, FutureExt}; -use tower::Service; -mod shutdown; +use serde::Serialize; +use tower::{Layer, 
Service}; + +/// Shutdown utilities +pub mod shutdown; use crate::{ + backend::Backend, error::BoxDynError, - executor::Executor, request::Request, - worker::{Context, Event, Ready, Worker}, - Backend, + worker::{Context, Event, EventHandler, Ready, Worker, WorkerId}, }; use self::shutdown::Shutdown; /// A monitor for coordinating and managing a collection of workers. -pub struct Monitor { - workers: Vec>>, - executor: E, - context: MonitorContext, +pub struct Monitor { + futures: Vec>, + workers: Vec>, terminator: Option>, -} - -/// The internal context of a [Monitor] -/// Usually shared with multiple workers -#[derive(Clone)] -pub struct MonitorContext { - #[allow(clippy::type_complexity)] - event_handler: Arc) + Send + Sync>>>>, shutdown: Shutdown, + event_handler: EventHandler, } -impl fmt::Debug for MonitorContext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MonitorContext") - .field("events", &self.event_handler.type_id()) - .field("shutdown", &"[Shutdown]") - .finish() - } -} - -impl MonitorContext { - fn new() -> MonitorContext { - Self { - event_handler: Arc::default(), - shutdown: Shutdown::new(), - } - } - - /// Get the shutdown handle - pub fn shutdown(&self) -> &Shutdown { - &self.shutdown - } - /// Get the events handle - pub fn notify(&self, event: Worker) { - let _ = self - .event_handler - .as_ref() - .read() - .map(|caller| caller.as_ref().map(|caller| caller(event))); - } -} - -impl Debug for Monitor { +impl Debug for Monitor { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Monitor") .field("shutdown", &"[Graceful shutdown listener]") - .field("workers", &self.workers) - .field("executor", &std::any::type_name::()) + .field("workers", &self.futures.len()) .finish() } } -impl Monitor { +impl Monitor { /// Registers a single instance of a [Worker] - pub fn register< - J: Send + Sync + 'static, - S: Service> + Send + 'static + Clone, - P: Backend> + 'static, - >( - mut self, - worker: 
Worker>, - ) -> Self + pub fn register(mut self, mut worker: Worker>) -> Self where + S: Service, Response = Res> + Send + 'static, S::Future: Send, - S::Response: 'static, + S::Response: Send + Sync + Serialize + 'static, S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, + P: Backend, Res> + Send + 'static, + P::Stream: Unpin + Send + 'static, + P::Layer: Layer + Send, + >::Service: Service, Response = Res> + Send, + <>::Service as Service>>::Future: Send, + <>::Service as Service>>::Error: + Send + Sync + Into, + Req: Send + Sync + 'static, + Ctx: Send + Sync + 'static, + Res: 'static, { - self.workers.push(worker.with_monitor(&self)); - + worker.state.shutdown = Some(self.shutdown.clone()); + worker.state.event_handler = self.event_handler.clone(); + let runnable = worker.run(); + let handle = runnable.get_handle(); + self.workers.push(handle); + self.futures.push(runnable.boxed()); self } @@ -107,23 +75,37 @@ impl Monitor { /// # Returns /// /// The monitor instance, with all workers added to the collection. - pub fn register_with_count< - J: Send + Sync + 'static, - S: Service> + Send + 'static + Clone, - P: Backend> + 'static, - >( + #[deprecated( + since = "0.6.0", + note = "Consider using the `.register` as workers now offer concurrency by default" + )] + pub fn register_with_count( mut self, count: usize, worker: Worker>, ) -> Self where + S: Service, Response = Res> + Send + 'static + Clone, S::Future: Send, - S::Response: 'static, + S::Response: Send + Sync + Serialize + 'static, S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, + P: Backend, Res> + Send + 'static + Clone, + P::Stream: Unpin + Send + 'static, + P::Layer: Layer + Send, + >::Service: Service, Response = Res> + Send, + <>::Service as Service>>::Future: Send, + <>::Service as Service>>::Error: + Send + Sync + Into, + Req: Send + Sync + 'static, + Ctx: Send + Sync + 'static, + Res: 'static, { - let workers = worker.with_monitor_instances(count, &self); - self.workers.extend(workers); + for index in 0..count { + let mut worker = worker.clone(); + let name = format!("{}-{index}", worker.id()); + worker.id = WorkerId::new(name); + self = self.register(worker); + } self } /// Runs the monitor and all its registered workers until they have all completed or a shutdown signal is received. @@ -135,19 +117,35 @@ impl Monitor { /// # Errors /// /// If the monitor fails to shutdown gracefully, an `std::io::Error` will be returned. + /// + /// # Remarks + /// + /// If a timeout has been set using the `Monitor::shutdown_timeout` method, the monitor + /// will wait for all workers to complete up to the timeout duration before exiting. + /// If the timeout is reached and workers have not completed, the monitor will exit forcefully. 
- pub async fn run_with_signal>>( - self, - signal: S, - ) -> std::io::Result<()> + pub async fn run_with_signal(self, signal: S) -> std::io::Result<()> where - E: Executor + Clone + Send + 'static, + S: Send + Future>, { - let shutdown = self.context.shutdown.clone(); - let shutdown_after = self.context.shutdown.shutdown_after(signal); - let runner = self.run(); - futures::try_join!(shutdown_after, runner)?; - shutdown.await; + let shutdown = self.shutdown.clone(); + let shutdown_after = self.shutdown.shutdown_after(signal); + if let Some(terminator) = self.terminator { + let _res = futures::future::select( + futures::future::join_all(self.futures) + .map(|_| shutdown.start_shutdown()) + .boxed(), + async { + let _res = shutdown_after.await; + terminator.await; + } + .boxed(), + ) + .await; + } else { + let runner = self.run(); + let _res = futures::join!(shutdown_after, runner); // If no terminator is provided, we wait for both the shutdown call and all workers to complete + } Ok(()) } @@ -155,103 +153,51 @@ impl Monitor { /// /// # Errors /// - /// If the monitor fails to shutdown gracefully, an `std::io::Error` will be returned. + /// If the monitor fails to run gracefully, an `std::io::Error` will be returned. /// /// # Remarks /// - /// If a timeout has been set using the `shutdown_timeout` method, the monitor - /// will wait for all workers to complete up to the timeout duration before exiting. - /// If the timeout is reached and workers have not completed, the monitor will exit forcefully. 
- pub async fn run(self) -> std::io::Result<()> - where - E: Executor + Clone + Send + 'static, - { - let mut futures = Vec::new(); - for worker in self.workers { - futures.push(worker.run().boxed()); - } - let shutdown_future = self.context.shutdown.boxed().map(|_| ()); - if let Some(terminator) = self.terminator { - let runner = futures::future::select( - futures::future::join_all(futures).map(|_| ()), - shutdown_future, - ); - futures::join!(runner, terminator); - } else { - futures::join!( - futures::future::join_all(futures).map(|_| ()), - shutdown_future, - ); - } + /// If all workers have completed execution, then by default the monitor will start a shutdown + pub async fn run(self) -> std::io::Result<()> { + let shutdown = self.shutdown.clone(); + let shutdown_future = self.shutdown.boxed().map(|_| ()); + futures::join!( + futures::future::join_all(self.futures).map(|_| shutdown.start_shutdown()), + shutdown_future, + ); + Ok(()) } /// Handles events emitted pub fn on_event) + Send + Sync + 'static>(self, f: F) -> Self { - let _ = self.context.event_handler.write().map(|mut res| { + let _ = self.event_handler.write().map(|mut res| { let _ = res.insert(Box::new(f)); }); self } - /// Get the current executor - pub fn executor(&self) -> &E { - &self.executor - } - - pub(crate) fn context(&self) -> &MonitorContext { - &self.context - } } -impl Default for Monitor { +impl Default for Monitor { fn default() -> Self { Self { - executor: E::default(), - context: MonitorContext::new(), - workers: Vec::new(), + shutdown: Shutdown::new(), + futures: Vec::new(), terminator: None, + event_handler: Arc::default(), + workers: Vec::new(), } } } -impl Monitor { +impl Monitor { /// Creates a new monitor instance. /// /// # Returns /// /// A new monitor instance, with an empty collection of workers. 
- pub fn new() -> Self - where - E: Default, - { - Self::new_with_executor(E::default()) - } - /// Creates a new monitor instance with an executor - /// - /// # Returns - /// - /// A new monitor instance, with an empty collection of workers. - pub fn new_with_executor(executor: E) -> Self { - Self { - context: MonitorContext::new(), - workers: Vec::new(), - executor, - terminator: None, - } - } - - /// Sets a custom executor for the monitor, allowing the usage of another runtime apart from Tokio. - /// The executor must implement the `Executor` trait. - pub fn set_executor(self, executor: NE) -> Monitor { - if !self.workers.is_empty() { - panic!("Tried changing executor when already loaded some workers"); - } - Monitor { - context: self.context, - workers: Vec::new(), - executor, - terminator: self.terminator, - } + pub fn new() -> Self { + Self::default() } /// Sets a timeout duration for the monitor's shutdown process. @@ -283,6 +229,7 @@ impl Monitor { #[cfg(test)] mod tests { + use crate::test_utils::apalis_test_service_fn; use std::{io, time::Duration}; use tokio::time::sleep; @@ -293,62 +240,65 @@ mod tests { monitor::Monitor, mq::MessageQueue, request::Request, - TestExecutor, + test_message_queue, + test_utils::TestWrapper, }; + test_message_queue!(MemoryStorage::new()); + #[tokio::test] - async fn it_works() { + async fn it_works_with_workers() { let backend = MemoryStorage::new(); - let handle = backend.clone(); + let mut handle = backend.clone(); tokio::spawn(async move { for i in 0..10 { handle.enqueue(i).await.unwrap(); } }); - let service = tower::service_fn(|request: Request| async { + let service = tower::service_fn(|request: Request| async { tokio::time::sleep(Duration::from_secs(1)).await; Ok::<_, io::Error>(request) }); let worker = WorkerBuilder::new("rango-tango") - .source(backend) + .backend(backend) .build(service); - let monitor: Monitor = Monitor::new(); + let monitor: Monitor = Monitor::new(); let monitor = monitor.register(worker); - let 
shutdown = monitor.context.shutdown.clone(); + let shutdown = monitor.shutdown.clone(); tokio::spawn(async move { sleep(Duration::from_millis(1500)).await; - shutdown.shutdown(); + shutdown.start_shutdown(); }); monitor.run().await.unwrap(); } #[tokio::test] async fn test_monitor_run() { let backend = MemoryStorage::new(); - let handle = backend.clone(); + let mut handle = backend.clone(); tokio::spawn(async move { - for i in 0..1000 { + for i in 0..10 { handle.enqueue(i).await.unwrap(); } }); - let service = tower::service_fn(|request: Request| async { + let service = tower::service_fn(|request: Request| async { tokio::time::sleep(Duration::from_secs(1)).await; Ok::<_, io::Error>(request) }); let worker = WorkerBuilder::new("rango-tango") - .source(backend) + .backend(backend) .build(service); - let monitor: Monitor = Monitor::new(); + let monitor: Monitor = Monitor::new(); let monitor = monitor.on_event(|e| { println!("{e:?}"); }); - let monitor = monitor.register_with_count(5, worker); - assert_eq!(monitor.workers.len(), 5); - let shutdown = monitor.context.shutdown.clone(); + let monitor = monitor.register(worker); + assert_eq!(monitor.futures.len(), 1); + let shutdown = monitor.shutdown.clone(); tokio::spawn(async move { sleep(Duration::from_millis(1000)).await; - shutdown.shutdown(); + shutdown.start_shutdown(); }); let result = monitor.run().await; diff --git a/packages/apalis-core/src/monitor/shutdown.rs b/packages/apalis-core/src/monitor/shutdown.rs index 83d188a..ce31554 100644 --- a/packages/apalis-core/src/monitor/shutdown.rs +++ b/packages/apalis-core/src/monitor/shutdown.rs @@ -16,17 +16,19 @@ pub struct Shutdown { } impl Shutdown { + /// Create a new shutdown handle pub fn new() -> Shutdown { Shutdown { inner: Arc::new(ShutdownCtx::new()), } } + /// Set the future to await before shutting down pub fn shutdown_after(&self, f: F) -> impl Future { let handle = self.clone(); async move { let result = f.await; - handle.shutdown(); + 
handle.start_shutdown(); result } } @@ -51,7 +53,6 @@ impl ShutdownCtx { } } fn shutdown(&self) { - // Set the shutdown state to true self.state.store(true, Ordering::Relaxed); self.wake(); } @@ -68,11 +69,13 @@ impl ShutdownCtx { } impl Shutdown { + /// Check if the system is shutting down pub fn is_shutting_down(&self) -> bool { self.inner.is_shutting_down() } - pub fn shutdown(&self) { + /// Start the shutdown process + pub fn start_shutdown(&self) { self.inner.shutdown() } } diff --git a/packages/apalis-core/src/mq/mod.rs b/packages/apalis-core/src/mq/mod.rs index 3a0eb61..e6d9a2d 100644 --- a/packages/apalis-core/src/mq/mod.rs +++ b/packages/apalis-core/src/mq/mod.rs @@ -4,36 +4,19 @@ use futures::Future; -use crate::{request::Request, Backend}; - /// Represents a message queue that can be pushed and consumed. -pub trait MessageQueue: Backend> { +pub trait MessageQueue { /// The error produced by the queue type Error; /// Enqueues a message to the queue. - fn enqueue(&self, message: Message) -> impl Future> + Send; + fn enqueue(&mut self, message: Message) + -> impl Future> + Send; /// Attempts to dequeue a message from the queue. /// Returns `None` if the queue is empty. - fn dequeue(&self) -> impl Future, Self::Error>> + Send; + fn dequeue(&mut self) -> impl Future, Self::Error>> + Send; /// Returns the current size of the queue. - fn size(&self) -> impl Future> + Send; -} - -/// Trait representing a job. -/// -/// -/// # Example -/// ```rust -/// # use apalis_core::mq::Message; -/// # struct Email; -/// impl Message for Email { -/// const NAME: &'static str = "redis::Email"; -/// } -/// ``` -pub trait Message { - /// Represents the name for job. 
- const NAME: &'static str; + fn size(&mut self) -> impl Future> + Send; } diff --git a/packages/apalis-core/src/poller/mod.rs b/packages/apalis-core/src/poller/mod.rs index 4e3855d..fb3468d 100644 --- a/packages/apalis-core/src/poller/mod.rs +++ b/packages/apalis-core/src/poller/mod.rs @@ -1,8 +1,6 @@ use futures::{future::BoxFuture, Future, FutureExt}; -use std::{ - fmt::{self, Debug}, - ops::{Deref, DerefMut}, -}; +use std::fmt::{self, Debug}; +use tower::layer::util::Identity; /// Util for controlling pollers pub mod controller; @@ -10,29 +8,47 @@ pub mod controller; pub mod stream; /// A poller type that allows fetching from a stream and a heartbeat future that can be used to do periodic tasks -pub struct Poller { - pub(crate) stream: S, - pub(crate) heartbeat: BoxFuture<'static, ()>, +pub struct Poller { + /// The stream of jobs + pub stream: S, + /// The heartbeat for the backend + pub heartbeat: BoxFuture<'static, ()>, + /// The tower middleware provided by the backend + pub layer: L, + pub(crate) _priv: (), } -impl Poller { +impl Poller { /// Build a new poller pub fn new(stream: S, heartbeat: impl Future + Send + 'static) -> Self { - Self { + Self::new_with_layer(stream, heartbeat, Identity::new()) + } + + /// Build a poller with layer + pub fn new_with_layer( + stream: S, + heartbeat: impl Future + Send + 'static, + layer: L, + ) -> Poller { + Poller { stream, heartbeat: heartbeat.boxed(), + layer, + _priv: (), } } } -impl Debug for Poller +impl Debug for Poller where S: Debug, + L: Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Poller") .field("stream", &self.stream) .field("heartbeat", &"...") + .field("layer", &self.layer) .finish() } } @@ -40,28 +56,3 @@ where const STOPPED: usize = 2; const PLUGGED: usize = 1; const UNPLUGGED: usize = 0; - -/// Tells the poller that the worker is ready for a new request -#[derive(Debug)] -pub struct FetchNext { - sender: async_oneshot::Sender, -} - -impl Deref for FetchNext { - 
type Target = async_oneshot::Sender; - fn deref(&self) -> &Self::Target { - &self.sender - } -} - -impl DerefMut for FetchNext { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.sender - } -} -impl FetchNext { - /// Generate a new instance of ready - pub fn new(sender: async_oneshot::Sender) -> Self { - Self { sender } - } -} diff --git a/packages/apalis-core/src/request.rs b/packages/apalis-core/src/request.rs index 2468b76..21d04ef 100644 --- a/packages/apalis-core/src/request.rs +++ b/packages/apalis-core/src/request.rs @@ -1,55 +1,154 @@ use futures::{future::BoxFuture, Stream}; use serde::{Deserialize, Serialize}; -use tower::{layer::util::Identity, ServiceBuilder}; +use tower::layer::util::Identity; -use std::{fmt::Debug, pin::Pin}; +use std::{fmt, fmt::Debug, pin::Pin, str::FromStr}; -use crate::{data::Extensions, error::Error, poller::Poller, worker::WorkerId, Backend}; +use crate::{ + backend::Backend, + data::Extensions, + error::Error, + poller::Poller, + task::{attempt::Attempt, namespace::Namespace, task_id::TaskId}, + worker::{Context, Worker}, +}; /// Represents a job which can be serialized and executed -#[derive(Serialize, Debug, Deserialize, Clone)] -pub struct Request { - pub(crate) req: T, +#[derive(Serialize, Debug, Deserialize, Clone, Default)] +pub struct Request { + /// The inner request part + pub args: Args, + /// Parts of the request eg id, attempts and context + pub parts: Parts, +} + +/// Component parts of a `Request` +#[non_exhaustive] +#[derive(Serialize, Debug, Deserialize, Clone, Default)] +pub struct Parts { + /// The request's id + pub task_id: TaskId, + + /// The request's extensions + #[serde(skip)] + pub data: Extensions, + + /// The request's attempts + pub attempt: Attempt, + + /// The Context stored by the storage + pub context: Ctx, + + /// Represents the namespace #[serde(skip)] - pub(crate) data: Extensions, + pub namespace: Option, + //TODO: add State } -impl Request { +impl Request { /// Creates a new 
[Request] - pub fn new(req: T) -> Self { - Self { - req, - data: Extensions::new(), - } + pub fn new(args: T) -> Self { + Self::new_with_data(args, Extensions::default(), Ctx::default()) + } + + /// Creates a request with all parts provided + pub fn new_with_parts(args: T, parts: Parts) -> Self { + Self { args, parts } } /// Creates a request with context provided - pub fn new_with_data(req: T, data: Extensions) -> Self { - Self { req, data } + pub fn new_with_ctx(req: T, ctx: Ctx) -> Self { + Self { + args: req, + parts: Parts { + context: ctx, + ..Default::default() + }, + } } - /// Get the underlying reference of the request - pub fn inner(&self) -> &T { - &self.req + /// Creates a request with data and context provided + pub fn new_with_data(req: T, data: Extensions, ctx: Ctx) -> Self { + Self { + args: req, + parts: Parts { + context: ctx, + data, + ..Default::default() + }, + } } - /// Take the underlying reference of the request - pub fn take(self) -> T { - self.req + /// Take the parts + pub fn take_parts(self) -> (T, Parts) { + (self.args, self.parts) } } -impl std::ops::Deref for Request { +impl std::ops::Deref for Request { type Target = Extensions; fn deref(&self) -> &Self::Target { - &self.data + &self.parts.data } } -impl std::ops::DerefMut for Request { +impl std::ops::DerefMut for Request { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.data + &mut self.parts.data + } +} + +/// Represents the state of a job/task +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, std::cmp::Eq)] +pub enum State { + /// Job is pending + #[serde(alias = "Latest")] + Pending, + /// Job is in the queue but not ready for execution + Scheduled, + /// Job is running + Running, + /// Job was done successfully + Done, + /// Job has failed. 
Check `last_error` + Failed, + /// Job has been killed + Killed, +} + +impl Default for State { + fn default() -> Self { + State::Pending + } +} + +impl FromStr for State { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "Pending" | "Latest" => Ok(State::Pending), + "Running" => Ok(State::Running), + "Done" => Ok(State::Done), + "Failed" => Ok(State::Failed), + "Killed" => Ok(State::Killed), + "Scheduled" => Ok(State::Scheduled), + _ => Err(Error::MissingData("Invalid Job state".to_string())), + } + } +} + +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self { + State::Pending => write!(f, "Pending"), + State::Running => write!(f, "Running"), + State::Done => write!(f, "Done"), + State::Failed => write!(f, "Failed"), + State::Killed => write!(f, "Killed"), + State::Scheduled => write!(f, "Scheduled"), + } } } @@ -61,19 +160,17 @@ pub type RequestFuture = BoxFuture<'static, T>; /// Represents a stream for T. pub type RequestStream = BoxStream<'static, Result, Error>>; -impl Backend> for RequestStream> { +impl Backend, Res> for RequestStream> { type Stream = Self; - type Layer = ServiceBuilder; - - fn common_layer(&self, _worker: WorkerId) -> Self::Layer { - ServiceBuilder::new() - } + type Layer = Identity; - fn poll(self, _worker: WorkerId) -> Poller { + fn poll(self, _worker: &Worker) -> Poller { Poller { stream: self, - heartbeat: Box::pin(async {}), + heartbeat: Box::pin(futures::future::pending()), + layer: Identity::new(), + _priv: (), } } } diff --git a/packages/apalis-core/src/response.rs b/packages/apalis-core/src/response.rs index 7c2a231..eb917be 100644 --- a/packages/apalis-core/src/response.rs +++ b/packages/apalis-core/src/response.rs @@ -1,6 +1,116 @@ -use std::any::Any; +use std::{any::Any, fmt::Debug, sync::Arc}; -use crate::error::Error; +use crate::{ + error::Error, + task::{attempt::Attempt, task_id::TaskId}, +}; + +/// A generic `Response` struct that wraps the result of 
a task, containing the outcome (`Ok` or `Err`), +/// task metadata such as `task_id`, `attempt`, and an internal marker field for future extensions. +/// +/// # Type Parameters +/// - `Res`: The successful result type of the response. +/// +/// # Fields +/// - `inner`: A `Result` that holds either the success value of type `Res` or an `Error` on failure. +/// - `task_id`: A `TaskId` representing the unique identifier for the task. +/// - `attempt`: An `Attempt` representing how many attempts were made to complete the task. +/// - `_priv`: A private marker field to prevent external construction of the `Response`. +#[derive(Debug, Clone)] +pub struct Response { + /// The result from a task + pub inner: Result, + /// The task id + pub task_id: TaskId, + /// The current attempt + pub attempt: Attempt, + pub(crate) _priv: (), +} + +impl Response { + /// Creates a new `Response` instance. + /// + /// # Arguments + /// - `inner`: A `Result` holding either a successful response of type `Res` or an `Error`. + /// - `task_id`: A `TaskId` representing the unique identifier for the task. + /// - `attempt`: The attempt count when creating this response. + /// + /// # Returns + /// A new `Response` instance. + pub fn new(inner: Result, task_id: TaskId, attempt: Attempt) -> Self { + Response { + inner, + task_id, + attempt, + _priv: (), + } + } + + /// Constructs a successful `Response`. + /// + /// # Arguments + /// - `res`: The success value of type `Res`. + /// - `task_id`: A `TaskId` representing the unique identifier for the task. + /// - `attempt`: The attempt count when creating this response. + /// + /// # Returns + /// A `Response` instance containing the success value. + pub fn success(res: Res, task_id: TaskId, attempt: Attempt) -> Self { + Self::new(Ok(res), task_id, attempt) + } + + /// Constructs a failed `Response`. + /// + /// # Arguments + /// - `error`: The `Error` that occurred. + /// - `task_id`: A `TaskId` representing the unique identifier for the task. 
+ /// - `attempt`: The attempt count when creating this response. + /// + /// # Returns + /// A `Response` instance containing the error. + pub fn failure(error: Error, task_id: TaskId, attempt: Attempt) -> Self { + Self::new(Err(error), task_id, attempt) + } + + /// Checks if the `Response` contains a success (`Ok`). + /// + /// # Returns + /// `true` if the `Response` is successful, `false` otherwise. + pub fn is_success(&self) -> bool { + self.inner.is_ok() + } + + /// Checks if the `Response` contains a failure (`Err`). + /// + /// # Returns + /// `true` if the `Response` is a failure, `false` otherwise. + pub fn is_failure(&self) -> bool { + self.inner.is_err() + } + + /// Maps the success value (`Res`) of the `Response` to another type using the provided function. + /// + /// # Arguments + /// - `f`: A function that takes a reference to the success value and returns a new value of type `T`. + /// + /// # Returns + /// A new `Response` with the transformed success value or the same error. + /// + /// # Type Parameters + /// - `F`: A function or closure that takes a reference to a value of type `Res` and returns a value of type `T`. + /// - `T`: The new type of the success value after mapping. 
+ pub fn map(&self, f: F) -> Response + where + F: FnOnce(&Res) -> T, + { + Response { + inner: self.inner.as_ref().map(f).map_err(|e| e.clone()), + task_id: self.task_id.clone(), + attempt: self.attempt.clone(), + _priv: (), + } + } +} /// Helper for Job Responses pub trait IntoResponse { @@ -15,22 +125,30 @@ impl IntoResponse for bool { fn into_response(self) -> std::result::Result { match self { true => Ok(true), - false => Err(Error::Failed(Box::new(std::io::Error::new( + false => Err(Error::Failed(Arc::new(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "Job returned false", - )))), + ))))), } } } -impl IntoResponse +impl IntoResponse for std::result::Result { type Result = Result; fn into_response(self) -> Result { match self { Ok(value) => Ok(value), - Err(e) => Err(Error::Failed(Box::new(e))), + Err(e) => { + // Try to downcast the error to see if it is already of type `Error` + if let Some(custom_error) = + (&e as &(dyn std::error::Error + 'static)).downcast_ref::() + { + return Err(custom_error.clone()); + } + Err(Error::Failed(Arc::new(Box::new(e)))) + } } } } diff --git a/packages/apalis-core/src/service_fn.rs b/packages/apalis-core/src/service_fn.rs index 66007e4..85ef4e3 100644 --- a/packages/apalis-core/src/service_fn.rs +++ b/packages/apalis-core/src/service_fn.rs @@ -1,3 +1,4 @@ +use crate::error::Error; use crate::layers::extensions::Data; use crate::request::Request; use crate::response::IntoResponse; @@ -10,20 +11,25 @@ use std::task::{Context, Poll}; use tower::Service; /// A helper method to build functions -pub fn service_fn(f: T) -> ServiceFn { - ServiceFn { f, k: PhantomData } +pub fn service_fn(f: T) -> ServiceFn { + ServiceFn { + f, + req: PhantomData, + fn_args: PhantomData, + } } /// An executable service implemented by a closure. /// /// See [`service_fn`] for more details. 
#[derive(Copy, Clone)] -pub struct ServiceFn { +pub struct ServiceFn { f: T, - k: PhantomData, + req: PhantomData>, + fn_args: PhantomData, } -impl fmt::Debug for ServiceFn { +impl fmt::Debug for ServiceFn { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ServiceFn") .field("f", &format_args!("{}", std::any::type_name::())) @@ -34,48 +40,62 @@ impl fmt::Debug for ServiceFn { /// The Future returned from [`ServiceFn`] service. pub type FnFuture = Map std::result::Result>; -/// Allows getting some type from the [Request] data -pub trait FromData: Sized + Clone + Send + Sync + 'static { - /// Gets the value - fn get(data: &crate::data::Extensions) -> Self { - data.get::().unwrap().clone() - } +/// Handles extraction +pub trait FromRequest: Sized { + /// Perform the extraction. + fn from_request(req: &Req) -> Result; } -impl FromData for Data { - fn get(ctx: &crate::data::Extensions) -> Self { - Data::new(ctx.get::().unwrap().clone()) +impl FromRequest> for Data { + fn from_request(req: &Request) -> Result { + req.parts.data.get_checked().cloned().map(Data::new) } } macro_rules! 
impl_service_fn { ($($K:ident),+) => { #[allow(unused_parens)] - impl Service> for ServiceFn + impl Service> for ServiceFn where T: FnMut(Req, $($K),+) -> F, F: Future, F::Output: IntoResponse>, - $($K: FromData),+, + $($K: FromRequest>),+, + E: From { type Response = R; type Error = E; - type Future = FnFuture; + type Future = futures::future::Either>, FnFuture>; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, task: Request) -> Self::Future { - let fut = (self.f)(task.req, $($K::get(&task.data)),+); + fn call(&mut self, task: Request) -> Self::Future { + + #[allow(non_snake_case)] + let fut = { + let results: Result<($($K),+), E> = (|| { + Ok(($($K::from_request(&task)?),+)) + })(); + + match results { + Ok(($($K),+)) => { + let req = task.args; + (self.f)(req, $($K),+) + } + Err(e) => return futures::future::Either::Left(futures::future::err(e).into()), + } + }; + - fut.map(F::Output::into_response) + futures::future::Either::Right(fut.map(F::Output::into_response)) } } }; } -impl Service> for ServiceFn +impl Service> for ServiceFn where T: FnMut(Req) -> F, F: Future, @@ -89,8 +109,8 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, task: Request) -> Self::Future { - let fut = (self.f)(task.req); + fn call(&mut self, task: Request) -> Self::Future { + let fut = (self.f)(task.args); fut.map(F::Output::into_response) } diff --git a/packages/apalis-core/src/storage/mod.rs b/packages/apalis-core/src/storage/mod.rs index 557154d..c761c74 100644 --- a/packages/apalis-core/src/storage/mod.rs +++ b/packages/apalis-core/src/storage/mod.rs @@ -1,78 +1,78 @@ use std::time::Duration; -use futures::{stream::BoxStream, Future}; +use futures::Future; -use crate::{request::Request, Backend}; - -/// The result of sa stream produced by a [Storage] -pub type StorageStream = BoxStream<'static, Result>, E>>; +use crate::{ + request::{Parts, Request}, + task::task_id::TaskId, +}; /// Represents a [Storage] that can persist a 
request. -/// The underlying type must implement [Job] -pub trait Storage: Backend> { +pub trait Storage { /// The type of job that can be persisted - type Job: Job; + type Job; /// The error produced by the storage type Error; - /// Jobs must have Ids. - type Identifier; + /// This is the type that storages store as the metadata related to a job + type Context: Default; /// Pushes a job to a storage fn push( &mut self, job: Self::Job, - ) -> impl Future> + Send; + ) -> impl Future, Self::Error>> + Send { + self.push_request(Request::new(job)) + } + + /// Pushes a constructed request to a storage + fn push_request( + &mut self, + req: Request, + ) -> impl Future, Self::Error>> + Send; - /// Push a job into the scheduled set + /// Push a job with defaults into the scheduled set fn schedule( &mut self, job: Self::Job, on: i64, - ) -> impl Future> + Send; + ) -> impl Future, Self::Error>> + Send { + self.schedule_request(Request::new(job), on) + } + + /// Push a request into the scheduled set + fn schedule_request( + &mut self, + request: Request, + on: i64, + ) -> impl Future, Self::Error>> + Send; /// Return the number of pending jobs from the queue - fn len(&self) -> impl Future> + Send; + fn len(&mut self) -> impl Future> + Send; /// Fetch a job given an id fn fetch_by_id( - &self, - job_id: &Self::Identifier, - ) -> impl Future>, Self::Error>> + Send; + &mut self, + job_id: &TaskId, + ) -> impl Future>, Self::Error>> + Send; /// Update a job details fn update( - &self, - job: Request, + &mut self, + job: Request, ) -> impl Future> + Send; /// Reschedule a job fn reschedule( &mut self, - job: Request, + job: Request, wait: Duration, ) -> impl Future> + Send; /// Returns true if there is no jobs in the storage - fn is_empty(&self) -> impl Future> + Send; + fn is_empty(&mut self) -> impl Future> + Send; /// Vacuum the storage, removes done and killed jobs - fn vacuum(&self) -> impl Future> + Send; -} - -/// Trait representing a job. 
-/// -/// -/// # Example -/// ```rust -/// # use apalis_core::storage::Job; -/// # struct Email; -/// impl Job for Email { -/// const NAME: &'static str = "apalis::Email"; -/// } -/// ``` -pub trait Job { - /// Represents the name for job. - const NAME: &'static str; + fn vacuum(&mut self) -> impl Future> + Send; } diff --git a/packages/apalis-core/src/task/attempt.rs b/packages/apalis-core/src/task/attempt.rs index ba557c4..9c1d84e 100644 --- a/packages/apalis-core/src/task/attempt.rs +++ b/packages/apalis-core/src/task/attempt.rs @@ -1,9 +1,52 @@ -use std::sync::{atomic::AtomicUsize, Arc}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{request::Request, service_fn::FromRequest}; /// A wrapper to keep count of the attempts tried by a task #[derive(Debug, Clone)] pub struct Attempt(Arc); +// Custom serialization function +fn serialize(attempt: &Attempt, serializer: S) -> Result +where + S: Serializer, +{ + let value = attempt.0.load(Ordering::SeqCst); + serializer.serialize_u64(value as u64) +} + +// Custom deserialization function +fn deserialize<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let value = u64::deserialize(deserializer)?; + Ok(Attempt(Arc::new(AtomicUsize::new(value as usize)))) +} + +impl Serialize for Attempt { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serialize(self, serializer) + } +} + +impl<'de> Deserialize<'de> for Attempt { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize(deserializer) + } +} + impl Default for Attempt { fn default() -> Self { Self(Arc::new(AtomicUsize::new(0))) @@ -31,3 +74,9 @@ impl Attempt { self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) } } + +impl FromRequest> for Attempt { + fn from_request(req: &Request) -> Result { + Ok(req.parts.attempt.clone()) + } +} diff --git a/packages/apalis-core/src/task/mod.rs 
b/packages/apalis-core/src/task/mod.rs index e50b782..169bd61 100644 --- a/packages/apalis-core/src/task/mod.rs +++ b/packages/apalis-core/src/task/mod.rs @@ -1,4 +1,6 @@ /// A unique tracker for number of attempts pub mod attempt; +/// A wrapper type for storing the namespace +pub mod namespace; /// A unique ID that can be used by a backend pub mod task_id; diff --git a/packages/apalis-core/src/task/namespace.rs b/packages/apalis-core/src/task/namespace.rs new file mode 100644 index 0000000..dfed96b --- /dev/null +++ b/packages/apalis-core/src/task/namespace.rs @@ -0,0 +1,52 @@ +use std::convert::From; +use std::fmt::{self, Display, Formatter}; +use std::ops::Deref; + +use serde::{Deserialize, Serialize}; + +use crate::error::Error; +use crate::request::Request; +use crate::service_fn::FromRequest; + +/// A wrapper type that defines a task's namespace. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Namespace(pub String); + +impl Deref for Namespace { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Display for Namespace { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for Namespace { + fn from(s: String) -> Self { + Namespace(s) + } +} + +impl From for String { + fn from(value: Namespace) -> String { + value.0 + } +} + +impl AsRef for Namespace { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl FromRequest> for Namespace { + fn from_request(req: &Request) -> Result { + let msg = "Missing `Namespace`. 
This is a bug, please file a report with the backend you are using".to_owned(); + req.parts.namespace.clone().ok_or(Error::MissingData(msg)) + } +} diff --git a/packages/apalis-core/src/task/task_id.rs b/packages/apalis-core/src/task/task_id.rs index 6d6a250..2296705 100644 --- a/packages/apalis-core/src/task/task_id.rs +++ b/packages/apalis-core/src/task/task_id.rs @@ -6,8 +6,10 @@ use std::{ use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use ulid::Ulid; +use crate::{error::Error, request::Request, service_fn::FromRequest}; + /// A wrapper type that defines a task id. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, Hash, PartialEq)] pub struct TaskId(Ulid); impl TaskId { @@ -58,6 +60,12 @@ impl<'de> Deserialize<'de> for TaskId { } } +impl FromRequest> for TaskId { + fn from_request(req: &Request) -> Result { + Ok(req.parts.task_id.clone()) + } +} + struct TaskIdVisitor; impl<'de> Visitor<'de> for TaskIdVisitor { diff --git a/packages/apalis-core/src/worker/mod.rs b/packages/apalis-core/src/worker/mod.rs index c92bef5..57261b4 100644 --- a/packages/apalis-core/src/worker/mod.rs +++ b/packages/apalis-core/src/worker/mod.rs @@ -1,15 +1,13 @@ -use self::stream::WorkerStream; +use crate::backend::Backend; use crate::error::{BoxDynError, Error}; -use crate::executor::Executor; use crate::layers::extensions::Data; -use crate::monitor::{Monitor, MonitorContext}; -use crate::notify::Notify; -use crate::poller::FetchNext; +use crate::monitor::shutdown::Shutdown; use crate::request::Request; -use crate::service_fn::FromData; -use crate::Backend; -use futures::future::Shared; -use futures::{Future, FutureExt}; +use crate::service_fn::FromRequest; +use crate::task::task_id::TaskId; +use futures::future::{join, select, BoxFuture}; +use futures::stream::BoxStream; +use futures::{Future, FutureExt, Stream, StreamExt}; use pin_project_lite::pin_project; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -18,70 +16,32 @@ use 
std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::str::FromStr; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::task::{Context as TaskCtx, Poll, Waker}; use thiserror::Error; -use tower::{Service, ServiceBuilder, ServiceExt}; - -mod stream; -// By default a worker starts 3 futures, one for polling, one for worker stream and the other for consuming. -const WORKER_FUTURES: usize = 3; - -type WorkerNotify = Notify>>; +use tower::util::CallAllUnordered; +use tower::{Layer, Service, ServiceBuilder}; /// A worker name wrapper usually used by Worker builder #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct WorkerId { name: String, - instance: Option, } +/// An event handler for [`Worker`] +pub type EventHandler = Arc) + Send + Sync>>>>; + impl FromStr for WorkerId { type Err = (); fn from_str(s: &str) -> Result { - let mut parts: Vec<&str> = s.rsplit('-').collect(); - - match parts.len() { - 1 => Ok(WorkerId { - name: parts[0].to_string(), - instance: None, - }), - _ => { - let instance_str = parts[0]; - match instance_str.parse() { - Ok(instance) => { - let remainder = &mut parts[1..]; - remainder.reverse(); - let name = remainder.join("-"); - Ok(WorkerId { - name: name.to_string(), - instance: Some(instance), - }) - } - Err(_) => Ok(WorkerId { - name: { - let all = &mut parts[0..]; - all.reverse(); - all.join("-") - }, - instance: None, - }), - } - } - } + Ok(WorkerId { name: s.to_owned() }) } } -impl FromData for WorkerId {} - impl Display for WorkerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(self.name())?; - if let Some(instance) = self.instance { - f.write_str("-")?; - f.write_str(&instance.to_string())?; - } Ok(()) } } @@ -91,26 +51,13 @@ impl WorkerId { pub fn new>(name: T) -> Self { Self { name: name.as_ref().to_string(), - instance: None, } } - /// Build a new worker ref - pub fn new_with_instance>(name: 
T, instance: usize) -> Self { - Self { - name: name.as_ref().to_string(), - instance: Some(instance), - } - } /// Get the name of the worker pub fn name(&self) -> &str { &self.name } - - /// Get the name of the worker - pub fn instance(&self) -> &Option { - &self.instance - } } /// Events emitted by a worker @@ -119,9 +66,11 @@ pub enum Event { /// Worker started Start, /// Worker got a job - Engage, + Engage(TaskId), /// Worker is idle, stream has no new request for now Idle, + /// A custom event + Custom(String), /// Worker encountered an error Error(BoxDynError), /// Worker stopped @@ -130,8 +79,24 @@ pub enum Event { Exit, } +impl fmt::Display for Worker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let event_description = match &self.state { + Event::Start => "Worker started".to_string(), + Event::Engage(task_id) => format!("Worker engaged with Task ID: {}", task_id), + Event::Idle => "Worker is idle".to_string(), + Event::Custom(msg) => format!("Custom event: {}", msg), + Event::Error(err) => format!("Worker encountered an error: {}", err), + Event::Stop => "Worker stopped".to_string(), + Event::Exit => "Worker completed all pending tasks and exited".to_string(), + }; + + write!(f, "Worker [{}]: {}", self.id.name, event_description) + } +} + /// Possible errors that can occur when starting a worker. -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone)] pub enum WorkerError { /// An error occurred while processing a job. 
#[error("Failed to process job: {0}")] @@ -145,26 +110,60 @@ pub enum WorkerError { } /// A worker that is ready for running -#[derive(Debug)] pub struct Ready { service: S, backend: P, + pub(crate) shutdown: Option, + pub(crate) event_handler: EventHandler, +} + +impl fmt::Debug for Ready +where + S: fmt::Debug, + P: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Ready") + .field("service", &self.service) + .field("backend", &self.backend) + .field("shutdown", &self.shutdown) + .field("event_handler", &"...") // Avoid dumping potentially sensitive or verbose data + .finish() + } } + +impl Clone for Ready +where + S: Clone, + P: Clone, +{ + fn clone(&self) -> Self { + Ready { + service: self.service.clone(), + backend: self.backend.clone(), + shutdown: self.shutdown.clone(), + event_handler: self.event_handler.clone(), + } + } +} + impl Ready { /// Build a worker that is ready for execution pub fn new(service: S, poller: P) -> Self { Ready { service, backend: poller, + shutdown: None, + event_handler: EventHandler::default(), } } } /// Represents a generic [Worker] that can be in many different states -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct Worker { - id: WorkerId, - state: T, + pub(crate) id: WorkerId, + pub(crate) state: T, } impl Worker { @@ -197,338 +196,227 @@ impl DerefMut for Worker { } } -impl Worker> { - /// Start a worker - pub async fn run(self) { - let instance = self.instance; - let monitor = self.state.context.clone(); - self.state.running.store(true, Ordering::Relaxed); - self.state.await; - if let Some(ctx) = monitor.as_ref() { - ctx.notify(Worker { - state: Event::Exit, - id: WorkerId::new_with_instance(self.id.name, instance), +impl Worker { + /// Allows workers to emit events + pub fn emit(&self, event: Event) -> bool { + if let Some(handler) = self.state.event_handler.read().unwrap().as_ref() { + handler(Worker { + id: self.id().clone(), + state: event, }); - }; + 
return true; + } + false } } -impl Worker> { - /// Start a worker with a custom executor - pub fn with_executor(self, executor: E) -> Worker> - where - S: Service> + Send + 'static + Clone, - P: Backend> + 'static, - J: Send + 'static + Sync, - S::Future: Send, - S::Response: 'static, - S::Error: Send + Sync + 'static + Into, - S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, - E: Executor + Clone + Send + 'static + Sync, - { - let notifier = Notify::new(); - let service = self.state.service; - let backend = self.state.backend; - let poller = backend.poll(self.id.clone()); - let polling = poller.heartbeat.shared(); - let worker_stream = WorkerStream::new(poller.stream, notifier.clone()) - .into_future() - .shared(); - Self::build_worker_instance( - WorkerId::new(self.id.name()), - service.clone(), - executor.clone(), - notifier.clone(), - polling.clone(), - worker_stream.clone(), - None, - ) +impl FromRequest> for Worker { + fn from_request(req: &Request) -> Result { + req.parts.data.get_checked().cloned() } +} - /// Run as a monitored worker - pub fn with_monitor(self, monitor: &Monitor) -> Worker> - where - S: Service> + Send + 'static + Clone, - P: Backend> + 'static, - J: Send + 'static + Sync, - S::Future: Send, - S::Response: 'static, - S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, - E: Executor + Clone + Send + 'static + Sync, - { - let notifier = Notify::new(); - let service = self.state.service; - let backend = self.state.backend; - let executor = monitor.executor().clone(); - let context = monitor.context().clone(); - let poller = backend.poll(self.id.clone()); - let polling = poller.heartbeat.shared(); - let worker_stream = WorkerStream::new(poller.stream, notifier.clone()) - .into_future() - .shared(); - Self::build_worker_instance( - WorkerId::new(self.id.name()), - service.clone(), - executor.clone(), - notifier.clone(), - polling.clone(), - worker_stream.clone(), - Some(context.clone()), - ) +impl Worker> { + /// Add an event handler to the worker + pub fn on_event) + Send + Sync + 'static>(self, f: F) -> Self { + let _ = self.event_handler.write().map(|mut res| { + let _ = res.insert(Box::new(f)); + }); + self } - /// Run a specified amounts of instances - pub fn with_monitor_instances( - self, - instances: usize, - monitor: &Monitor, - ) -> Vec>> + fn poll_jobs( + worker: Worker, + service: Svc, + stream: Stm, + ) -> BoxStream<'static, ()> where - S: Service> + Send + 'static + Clone, - P: Backend> + 'static, - J: Send + 'static + Sync, - S::Future: Send, - S::Response: 'static, - S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, - E: Executor + Clone + Send + 'static + Sync, + Svc: Service, Response = Res> + Send + 'static, + Stm: Stream>, Error>> + Send + Unpin + 'static, + Req: Send + 'static + Sync, + Svc::Future: Send, + Svc::Response: 'static + Send + Sync + Serialize, + Svc::Error: Send + Sync + 'static + Into, + Ctx: Send + 'static + Sync, + Res: 'static, { - let notifier = Notify::new(); - let service = self.state.service; - let backend = self.state.backend; - let executor = monitor.executor().clone(); - let context = monitor.context().clone(); - let poller = backend.poll(self.id.clone()); - let polling = poller.heartbeat.shared(); - let worker_stream = WorkerStream::new(poller.stream, notifier.clone()) - .into_future() - .shared(); - let mut workers = Vec::new(); - - for instance in 0..instances { - workers.push(Self::build_worker_instance( - WorkerId::new_with_instance(self.id.name(), instance), - service.clone(), - executor.clone(), - notifier.clone(), - polling.clone(), - worker_stream.clone(), - Some(context.clone()), - )); - } - - workers + let w = worker.clone(); + let stream = stream.filter_map(move |result| { + let worker = worker.clone(); + + async move { + match result { + Ok(Some(request)) => { + worker.emit(Event::Engage(request.parts.task_id.clone())); + Some(request) + } + Ok(None) => { + worker.emit(Event::Idle); + None + } + Err(err) => { + worker.emit(Event::Error(Box::new(err))); + None + } + } + } + }); + let stream = CallAllUnordered::new(service, stream).map(move |res| { + if let Err(error) = res { + let error = error.into(); + if let Some(Error::MissingData(_)) = error.downcast_ref::() { + w.stop(); + } + w.emit(Event::Error(error)); + } + }); + stream.boxed() } - - /// Run specified worker instances via a specific executor - pub fn with_executor_instances( - self, - instances: usize, - executor: E, - ) -> Vec>> + /// Start a worker + pub fn run(self) -> Runnable where - S: Service> + Send + 'static + Clone, - P: Backend> + 
'static, - J: Send + 'static + Sync, + S: Service, Response = Res> + Send + 'static, + P: Backend, Res> + 'static, + Req: Send + 'static + Sync, S::Future: Send, - S::Response: 'static, - S::Error: Send + Sync + 'static + Into, + S::Response: 'static + Send + Sync + Serialize, S::Error: Send + Sync + 'static + Into, -

>>::Stream: Unpin + Send + 'static, - E: Executor + Clone + Send + 'static + Sync, - { - let worker_id = self.id.clone(); - let notifier = Notify::new(); - let service = self.state.service; - let backend = self.state.backend; - let poller = backend.poll(worker_id.clone()); - let polling = poller.heartbeat.shared(); - let worker_stream = WorkerStream::new(poller.stream, notifier.clone()) - .into_future() - .shared(); - - let mut workers = Vec::new(); - for instance in 0..instances { - workers.push(Self::build_worker_instance( - WorkerId::new_with_instance(self.id.name(), instance), - service.clone(), - executor.clone(), - notifier.clone(), - polling.clone(), - worker_stream.clone(), - None, - )); - } - workers - } - - pub(crate) fn build_worker_instance( - id: WorkerId, - service: LS, - executor: E, - notifier: WorkerNotify>, Error>>, - polling: Shared + Send + 'static>, - worker_stream: Shared + Send + 'static>, - context: Option, - ) -> Worker> - where - LS: Service> + Send + 'static + Clone, - LS::Future: Send + 'static, - LS::Response: 'static, - LS::Error: Send + Sync + Into + 'static, - P: Backend>, - E: Executor + Send + Clone + 'static + Sync, - J: Sync + Send + 'static, - S: 'static, - P: 'static, + P::Stream: Unpin + Send + 'static, + P::Layer: Layer, + >::Service: Service, Response = Res> + Send, + <>::Service as Service>>::Future: Send, + <>::Service as Service>>::Error: + Send + Into + Sync, + Ctx: Send + 'static + Sync, + Res: 'static, { - let instance = id.instance.unwrap_or_default(); + let worker_id = self.id().clone(); let ctx = Context { - context, - executor, - instance, running: Arc::default(), task_count: Arc::default(), wakers: Arc::default(), + shutdown: self.state.shutdown, + event_handler: self.state.event_handler.clone(), }; - let worker = Worker { id, state: ctx }; - - let fut = Self::build_instance(instance, service, worker.clone(), notifier); - - worker.spawn(fut); - worker.spawn(polling); - worker.spawn(worker_stream); - worker - } - - 
pub(crate) async fn build_instance( - instance: usize, - service: LS, - worker: Worker>, - notifier: WorkerNotify>, Error>>, - ) where - LS: Service> + Send + 'static + Clone, - LS::Future: Send + 'static, - LS::Response: 'static, - LS::Error: Send + Sync + Into + 'static, - P: Backend>, - E: Executor + Send + Clone + 'static + Sync, - { - if let Some(ctx) = worker.state.context.as_ref() { - ctx.notify(Worker { - state: Event::Start, - id: WorkerId::new_with_instance(worker.id.name(), instance), - }); + let worker = Worker { + id: worker_id.clone(), + state: ctx.clone(), }; - let worker_layers = ServiceBuilder::new() - .layer(Data::new(worker.id.clone())) - .layer(Data::new(worker.state.clone())); - let mut service = worker_layers.service(service); - worker.running.store(true, Ordering::Relaxed); - let worker_id = worker.id().clone(); - loop { - if worker.is_shutting_down() { - if let Some(ctx) = worker.state.context.as_ref() { - ctx.notify(Worker { - state: Event::Stop, - id: WorkerId::new_with_instance(worker.id.name(), instance), - }); - }; - break; - } - match service.ready().await { - Ok(service) => { - let (sender, receiver) = async_oneshot::oneshot(); - let res = notifier.notify(Worker { - id: WorkerId::new_with_instance(worker.id.name(), instance), - state: FetchNext::new(sender), - }); - - if res.is_ok() { - match receiver.await { - Ok(Ok(Some(req))) => { - let fut = service.call(req); - let worker_id = worker_id.clone(); - let state = worker.state.clone(); - worker.spawn(fut.map(move |res| { - if let Err(e) = res { - if let Some(ctx) = state.context.as_ref() { - ctx.notify(Worker { - state: Event::Error(e.into()), - id: WorkerId::new_with_instance( - worker_id.name(), - instance, - ), - }); - }; - } - })); - } - Ok(Err(e)) => { - if let Some(ctx) = worker.state.context.as_ref() { - ctx.notify(Worker { - state: Event::Error(Box::new(e)), - id: WorkerId::new_with_instance(worker.id.name(), instance), - }); - }; - } - Ok(Ok(None)) => { - if let Some(ctx) = 
worker.state.context.as_ref() { - ctx.notify(Worker { - state: Event::Idle, - id: WorkerId::new_with_instance(worker.id.name(), instance), - }); - }; - } - Err(_) => { - // Listener was dropped, no need to notify - } - } - } - } - Err(e) => { - if let Some(ctx) = worker.state.context.as_ref() { - ctx.notify(Worker { - state: Event::Error(e.into()), - id: WorkerId::new_with_instance(worker.id.name(), instance), - }); - }; - } + let backend = self.state.backend; + let service = self.state.service; + let poller = backend.poll::(&worker); + let stream = poller.stream; + let heartbeat = poller.heartbeat.boxed(); + let layer = poller.layer; + let service = ServiceBuilder::new() + .layer(TrackerLayer::new(worker.state.clone())) + .layer(Data::new(worker.clone())) + .layer(layer) + .service(service); + + Runnable { + poller: Self::poll_jobs(worker.clone(), service, stream), + heartbeat, + worker, + running: false, + } + } +} + +/// A `Runnable` represents a unit of work that manages a worker's lifecycle and execution flow. +/// +/// The `Runnable` struct is responsible for coordinating the core tasks of a worker, such as polling for jobs, +/// maintaining heartbeats, and tracking its running state. It integrates various components required for +/// the worker to operate effectively within an asynchronous runtime. 
+#[must_use = "A Runnable must be awaited of no jobs will be consumed"] +pub struct Runnable { + poller: BoxStream<'static, ()>, + heartbeat: BoxFuture<'static, ()>, + worker: Worker, + running: bool, +} + +impl Runnable { + /// Returns a handle to the worker, allowing control and functionality like stopping + pub fn get_handle(&self) -> Worker { + self.worker.clone() + } +} + +impl fmt::Debug for Runnable { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Runnable") + .field("poller", &"") + .field("heartbeat", &"") + .field("worker", &self.worker) + .field("running", &self.running) + .finish() + } +} + +impl Future for Runnable { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let this = self.get_mut(); + let poller = &mut this.poller; + let heartbeat = &mut this.heartbeat; + let worker = &mut this.worker; + + let poller_future = async { while (poller.next().await).is_some() {} }; + + if !this.running { + worker.running.store(true, Ordering::Relaxed); + this.running = true; + worker.emit(Event::Start); + } + let combined = Box::pin(join(poller_future, heartbeat.as_mut())); + + let mut combined = select( + combined, + worker.state.clone().map(|_| worker.emit(Event::Stop)), + ) + .boxed(); + match Pin::new(&mut combined).poll(cx) { + Poll::Ready(_) => { + worker.emit(Event::Exit); + Poll::Ready(()) } + Poll::Pending => Poll::Pending, } } } + /// Stores the Workers context -#[derive(Clone)] -pub struct Context { - context: Option, - executor: E, +#[derive(Clone, Default)] +pub struct Context { task_count: Arc, wakers: Arc>>, running: Arc, - instance: usize, + shutdown: Option, + event_handler: EventHandler, } -impl fmt::Debug for Context { +impl fmt::Debug for Context { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WorkerContext") .field("shutdown", &["Shutdown handle"]) - .field("instance", &self.instance) + .field("task_count", &self.task_count) + 
.field("running", &self.running) .finish() } } pin_project! { - struct Tracked { - worker: Context, + /// A future tracked by the worker + pub struct Tracked { + ctx: Context, #[pin] task: F, } } -impl Future for Tracked { +impl Future for Tracked { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll { @@ -536,7 +424,7 @@ impl Future for Tracked { match this.task.poll(cx) { res @ Poll::Ready(_) => { - this.worker.end_task(); + this.ctx.end_task(); res } Poll::Pending => Poll::Pending, @@ -544,21 +432,17 @@ impl Future for Tracked { } } -impl Context { - /// Allows spawning of futures that will be gracefully shutdown by the worker - pub fn spawn(&self, future: impl Future + Send + 'static) { - self.executor.spawn(self.track(future)); - } - - fn track>(&self, task: F) -> Tracked { +impl Context { + /// Start a task that is tracked by the worker + pub fn track(&self, task: F) -> Tracked { self.start_task(); Tracked { - worker: self.clone(), + ctx: self.clone(), task, } } - /// Calling this function triggers shutting down the worker + /// Calling this function triggers shutting down the worker while waiting for any tasks to complete pub fn stop(&self) { self.running.store(false, Ordering::Relaxed); self.wake() @@ -569,7 +453,7 @@ impl Context { } fn end_task(&self) { - if self.task_count.fetch_sub(1, Ordering::Relaxed) == WORKER_FUTURES { + if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 { self.wake(); } } @@ -587,12 +471,23 @@ impl Context { self.running.load(Ordering::Relaxed) } + /// Returns the current futures in the worker domain + /// This include futures spawned via `worker.track` + pub fn task_count(&self) -> usize { + self.task_count.load(Ordering::Relaxed) + } + + /// Returns whether the worker has pending tasks + pub fn has_pending_tasks(&self) -> bool { + self.task_count.load(Ordering::Relaxed) > 0 + } + /// Is the shutdown token called pub fn is_shutting_down(&self) -> bool { - self.context + self.shutdown .as_ref() 
- .map(|s| s.shutdown().is_shutting_down()) - .unwrap_or(false) + .map(|s| !self.is_running() || s.is_shutting_down()) + .unwrap_or(!self.is_running()) } fn add_waker(&self, cx: &mut TaskCtx<'_>) { @@ -604,22 +499,13 @@ impl Context { } } -impl FromData for Context {} - -impl Future for Context { +impl Future for Context { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<()> { - let running = self.is_running(); let task_count = self.task_count.load(Ordering::Relaxed); - if self.is_shutting_down() || !running { - if task_count <= WORKER_FUTURES { - self.stop(); - Poll::Ready(()) - } else { - self.add_waker(cx); - Poll::Pending - } + if self.is_shutting_down() && task_count == 0 { + Poll::Ready(()) } else { self.add_waker(cx); Poll::Pending @@ -627,18 +513,53 @@ impl Future for Context { } } -#[cfg(test)] -mod tests { - use std::{io, ops::Deref, sync::atomic::AtomicUsize, time::Duration}; +#[derive(Debug, Clone)] +struct TrackerLayer { + ctx: Context, +} - #[derive(Debug, Clone)] - struct TokioTestExecutor; +impl TrackerLayer { + fn new(ctx: Context) -> Self { + Self { ctx } + } +} - impl Executor for TokioTestExecutor { - fn spawn(&self, future: impl Future + Send + 'static) { - tokio::spawn(future); +impl Layer for TrackerLayer { + type Service = TrackerService; + + fn layer(&self, service: S) -> Self::Service { + TrackerService { + ctx: self.ctx.clone(), + service, } } +} +#[derive(Debug, Clone)] +struct TrackerService { + ctx: Context, + service: S, +} + +impl Service> for TrackerService +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = Tracked; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + self.ctx.track(self.service.call(request)) + } +} + +#[cfg(test)] +mod tests { + use std::{ops::Deref, sync::atomic::AtomicUsize}; use crate::{ builder::{WorkerBuilder, 
WorkerFactoryFn}, @@ -656,30 +577,27 @@ mod tests { assert_eq!( WorkerId::from_str("worker").unwrap(), WorkerId { - instance: None, name: "worker".to_string() } ); assert_eq!( WorkerId::from_str("worker-0").unwrap(), WorkerId { - instance: Some(0), - name: "worker".to_string() + name: "worker-0".to_string() } ); assert_eq!( WorkerId::from_str("complex&*-worker-name-0").unwrap(), WorkerId { - instance: Some(0), - name: "complex&*-worker-name".to_string() + name: "complex&*-worker-name-0".to_string() } ); } #[tokio::test] async fn it_works() { - let backend = MemoryStorage::new(); - let handle = backend.clone(); + let in_memory = MemoryStorage::new(); + let mut handle = in_memory.clone(); tokio::spawn(async move { for i in 0..ITEMS { @@ -697,25 +615,16 @@ mod tests { } } - async fn task(job: u32, count: Data) -> Result<(), io::Error> { + async fn task(job: u32, count: Data, worker: Worker) { count.fetch_add(1, Ordering::Relaxed); if job == ITEMS - 1 { - tokio::time::sleep(Duration::from_secs(1)).await; + worker.stop(); } - Ok(()) } let worker = WorkerBuilder::new("rango-tango") - // .chain(|svc| svc.timeout(Duration::from_millis(500))) .data(Count::default()) - .source(backend); + .backend(in_memory); let worker = worker.build_fn(task); - let worker = worker.with_executor(TokioTestExecutor); - let w = worker.clone(); - - tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(3)).await; - w.stop(); - }); worker.run().await; } } diff --git a/packages/apalis-core/src/worker/stream.rs b/packages/apalis-core/src/worker/stream.rs deleted file mode 100644 index 820e289..0000000 --- a/packages/apalis-core/src/worker/stream.rs +++ /dev/null @@ -1,56 +0,0 @@ -use futures::{Future, Stream, StreamExt}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use super::WorkerNotify; - -// Define your struct -pub(crate) struct WorkerStream -where - S: Stream, -{ - notify: WorkerNotify, - stream: S, -} - -impl WorkerStream -where - S: Stream + Unpin + 'static, -{ - 
pub(crate) fn new(stream: S, notify: WorkerNotify) -> Self { - Self { notify, stream } - } - pub(crate) fn into_future(mut self) -> impl Future { - Box::pin(async move { - loop { - self.next().await; - } - }) - } -} - -impl Stream for WorkerStream -where - S: Stream + Unpin, -{ - type Item = (); - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - // Poll for the next listener - match this.notify.poll_next_unpin(cx) { - Poll::Ready(Some(mut worker)) => { - match this.stream.poll_next_unpin(cx) { - Poll::Ready(Some(item)) => { - if let Err(_e) = worker.send(item) {} - Poll::Ready(Some(())) - } - Poll::Ready(None) => Poll::Ready(None), // Inner stream is exhausted - Poll::Pending => Poll::Pending, - } - } - Poll::Ready(None) => Poll::Ready(None), // No more workers - Poll::Pending => Poll::Pending, - } - } -} diff --git a/packages/apalis-cron/Cargo.toml b/packages/apalis-cron/Cargo.toml index a13ac9f..e9c607b 100644 --- a/packages/apalis-cron/Cargo.toml +++ b/packages/apalis-cron/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "apalis-cron" -version = "0.5.5" +version = "0.6.0" edition.workspace = true repository.workspace = true authors = ["Njuguna Mureithi "] @@ -10,9 +10,8 @@ description = "A simple yet extensible library for cron-like job scheduling for # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -apalis-core = { path = "../../packages/apalis-core", version = "0.5.5", default-features = false, features = [ +apalis-core = { path = "../../packages/apalis-core", version = "0.6.0", default-features = false, features = [ "sleep", - "json", ] } cron = "0.13.0" futures = "0.3.30" @@ -22,7 +21,7 @@ chrono = { version = "0.4.38", default-features = false, features = [ "serde", ] } async-stream = "0.3.5" -async-std = { version = "1.12.0", optional = true } +async-std = { version = "1.13.0", optional = true } [dev-dependencies] tokio = { version = "1", features 
= ["macros"] } @@ -30,10 +29,6 @@ apalis-core = { path = "../../packages/apalis-core" } apalis = { path = "../../", default-features = false, features = ["retry"] } serde = { version = "1.0", features = ["derive"] } -[features] -default = ["tokio-comp"] -async-std-comp = ["async-std"] -tokio-comp = ["tokio/net"] [package.metadata.docs.rs] # defines the configuration attribute `docsrs` diff --git a/packages/apalis-cron/README.md b/packages/apalis-cron/README.md index ee1158b..deb30ff 100644 --- a/packages/apalis-cron/README.md +++ b/packages/apalis-cron/README.md @@ -6,34 +6,40 @@ Since apalis-cron is build on top of apalis which supports tower middleware, you ## Example ```rust -use apalis::prelude::*; -use apalis::layers::{Extension, DefaultRetryPolicy, RetryLayer}; -use apalis::cron::Schedule; +use apalis::layers::retry::RetryLayer; +use apalis::layers::retry::RetryPolicy; use tower::ServiceBuilder; +use apalis_cron::Schedule; use std::str::FromStr; +use apalis::prelude::*; +use apalis_cron::CronStream; +use chrono::{DateTime, Utc}; -#[derive(Default, Debug, Clone)] -struct Reminder; +#[derive(Clone)] +struct FakeService; +impl FakeService { + fn execute(&self, item: Reminder){} +} -impl Job for Reminder { - const NAME: &'static str = "reminder::DailyReminder"; +#[derive(Default, Debug, Clone)] +struct Reminder(DateTime); +impl From> for Reminder { + fn from(t: DateTime) -> Self { + Reminder(t) + } } -async fn send_reminder(job: Reminder, ctx: JobContext) { - // Do reminder stuff +async fn send_reminder(job: Reminder, svc: Data) { + svc.execute(job); } #[tokio::main] async fn main() { let schedule = Schedule::from_str("@daily").unwrap(); - - let service = ServiceBuilder::new() - .layer(RetryLayer::new(DefaultRetryPolicy)) - .service(job_fn(send_reminder)); - - let worker = WorkerBuilder::new("daily-cron-worker") - .stream(CronStream::new(schedule).to_stream()) - .build(service); - + let worker = WorkerBuilder::new("morning-cereal") + 
.retry(RetryPolicy::retries(5)) + .data(FakeService) + .stream(CronStream::new(schedule).into_stream()) + .build_fn(send_reminder); Monitor::new() .register(worker) .run() diff --git a/packages/apalis-cron/src/lib.rs b/packages/apalis-cron/src/lib.rs index 3a16355..4db6d62 100644 --- a/packages/apalis-cron/src/lib.rs +++ b/packages/apalis-cron/src/lib.rs @@ -15,17 +15,12 @@ //! ## Example //! //! ```rust,no_run -//! # use apalis_utils::layers::retry::RetryLayer; -//! # use apalis_utils::layers::retry::DefaultRetryPolicy; -//! # use apalis_core::extensions::Data; -//! # use apalis_core::service_fn::service_fn; +//! # use apalis::layers::retry::RetryLayer; +//! # use apalis::layers::retry::RetryPolicy; //! use tower::ServiceBuilder; //! use apalis_cron::Schedule; //! use std::str::FromStr; -//! # use apalis_core::monitor::Monitor; -//! # use apalis_core::builder::WorkerBuilder; -//! # use apalis_core::builder::WorkerFactoryFn; -//! # use apalis_utils::TokioExecutor; +//! # use apalis::prelude::*; //! use apalis_cron::CronStream; //! use chrono::{DateTime, Utc}; //! @@ -50,11 +45,11 @@ //! async fn main() { //! let schedule = Schedule::from_str("@daily").unwrap(); //! let worker = WorkerBuilder::new("morning-cereal") -//! .layer(RetryLayer::new(DefaultRetryPolicy)) +//! .retry(RetryPolicy::retries(5)) //! .data(FakeService) -//! .stream(CronStream::new(schedule).into_stream()) +//! .backend(CronStream::new(schedule)) //! .build_fn(send_reminder); -//! Monitor::::new() +//! Monitor::new() //! .register(worker) //! .run() //! .await @@ -62,13 +57,25 @@ //! } //! 
``` -use apalis_core::data::Extensions; +use apalis_core::backend::Backend; +use apalis_core::error::BoxDynError; +use apalis_core::layers::Identity; +use apalis_core::mq::MessageQueue; +use apalis_core::poller::Poller; use apalis_core::request::RequestStream; -use apalis_core::task::task_id::TaskId; +use apalis_core::storage::Storage; +use apalis_core::task::namespace::Namespace; +use apalis_core::worker::{Context, Worker}; use apalis_core::{error::Error, request::Request}; use chrono::{DateTime, TimeZone, Utc}; pub use cron::Schedule; +use futures::StreamExt; +use pipe::CronPipe; use std::marker::PhantomData; +use std::sync::Arc; + +/// Allows piping of cronjobs to a Storage or MessageQueue +pub mod pipe; /// Represents a stream from a cron schedule with a timezone #[derive(Clone, Debug)] @@ -102,14 +109,14 @@ where } } } -impl CronStream +impl CronStream where - J: From> + Send + Sync + 'static, + Req: From> + Send + Sync + 'static, Tz: TimeZone + Send + Sync + 'static, Tz::Offset: Send + Sync, { /// Convert to consumable - pub fn into_stream(self) -> RequestStream> { + fn into_stream(self) -> RequestStream> { let timezone = self.timezone.clone(); let stream = async_stream::stream! 
{ let mut schedule = self.schedule.upcoming_owned(timezone.clone()); @@ -118,11 +125,13 @@ where match next { Some(next) => { let to_sleep = next - timezone.from_utc_datetime(&Utc::now().naive_utc()); - let to_sleep = to_sleep.to_std().map_err(|e| Error::Failed(e.into()))?; + let to_sleep = to_sleep.to_std().map_err(|e| Error::SourceError(Arc::new(e.into())))?; apalis_core::sleep(to_sleep).await; - let mut data = Extensions::new(); - data.insert(TaskId::new()); - yield Ok(Some(Request::new_with_data(J::from(timezone.from_utc_datetime(&Utc::now().naive_utc())), data))); + let timestamp = timezone.from_utc_datetime(&Utc::now().naive_utc()); + let namespace = Namespace(format!("{}:{timestamp:?}", self.schedule)); + let mut req = Request::new(Req::from(timestamp)); + req.parts.namespace = Some(namespace); + yield Ok(Some(req)); }, None => { yield Ok(None); @@ -132,4 +141,80 @@ where }; Box::pin(stream) } + + /// Push cron job events to a storage and get a consumable Backend + pub fn pipe_to_storage(self, storage: S) -> CronPipe + where + S: Storage + Clone + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, + { + let stream = self + .into_stream() + .then({ + let storage = storage.clone(); + move |res| { + let mut storage = storage.clone(); + async move { + match res { + Ok(Some(req)) => storage + .push(req.args) + .await + .map(|_| ()) + .map_err(|e| Box::new(e) as BoxDynError), + _ => Ok(()), + } + } + } + }) + .boxed(); + + CronPipe { + stream, + inner: storage, + } + } + /// Push cron job events to a message queue and get a consumable Backend + pub fn pipe_to_mq(self, mq: Mq) -> CronPipe + where + Mq: MessageQueue + Clone + Send + Sync + 'static, + Mq::Error: std::error::Error + Send + Sync + 'static, + { + let stream = self + .into_stream() + .then({ + let mq = mq.clone(); + move |res| { + let mut mq = mq.clone(); + async move { + match res { + Ok(Some(req)) => mq + .enqueue(req.args) + .await + .map(|_| ()) + .map_err(|e| Box::new(e) 
as BoxDynError), + _ => Ok(()), + } + } + } + }) + .boxed(); + + CronPipe { stream, inner: mq } + } +} + +impl Backend, Res> for CronStream +where + Req: From> + Send + Sync + 'static, + Tz: TimeZone + Send + Sync + 'static, + Tz::Offset: Send + Sync, +{ + type Stream = RequestStream>; + + type Layer = Identity; + + fn poll(self, _worker: &Worker) -> Poller { + let stream = self.into_stream(); + Poller::new(stream, futures::future::pending()) + } } diff --git a/packages/apalis-cron/src/pipe.rs b/packages/apalis-cron/src/pipe.rs new file mode 100644 index 0000000..bad125a --- /dev/null +++ b/packages/apalis-cron/src/pipe.rs @@ -0,0 +1,77 @@ +use apalis_core::backend::Backend; +use apalis_core::error::BoxDynError; +use apalis_core::request::BoxStream; +use apalis_core::{poller::Poller, request::Request, worker::Context, worker::Worker}; +use futures::StreamExt; +use std::{error, fmt}; +use tower::Service; + +/// A generic Pipe that wraps an inner type along with a `RequestStream`. +pub struct CronPipe { + pub(crate) stream: BoxStream<'static, Result<(), BoxDynError>>, + pub(crate) inner: Inner, +} + +impl fmt::Debug for CronPipe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pipe") + .field("stream", &">") // Placeholder as `RequestStream` might not implement Debug + .field("inner", &self.inner) + .finish() + } +} + +impl Backend, Res> for CronPipe +where + Inner: Backend, Res>, +{ + type Stream = Inner::Stream; + + type Layer = Inner::Layer; + + fn poll, Response = Res>>( + mut self, + worker: &Worker, + ) -> Poller { + let pipe_heartbeat = async move { while (self.stream.next().await).is_some() {} }; + let inner = self.inner.poll::(worker); + let heartbeat = inner.heartbeat; + + Poller::new_with_layer( + inner.stream, + async { + futures::join!(heartbeat, pipe_heartbeat); + }, + inner.layer, + ) + } +} + +/// A cron error +#[derive(Debug)] +pub struct PipeError { + kind: PipeErrorKind, +} + +/// The kind of pipe error that occurred 
+#[derive(Debug)] +pub enum PipeErrorKind { + /// The cron stream provided a None + EmptyStream, +} + +impl fmt::Display for PipeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + PipeErrorKind::EmptyStream => write!(f, "The cron stream provided a None",), + } + } +} + +impl error::Error for PipeError {} + +impl From for PipeError { + fn from(kind: PipeErrorKind) -> PipeError { + PipeError { kind } + } +} diff --git a/packages/apalis-redis/Cargo.toml b/packages/apalis-redis/Cargo.toml index 942e759..ca9a2e9 100644 --- a/packages/apalis-redis/Cargo.toml +++ b/packages/apalis-redis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "apalis-redis" -version = "0.5.5" +version = "0.6.0" authors = ["Njuguna Mureithi "] edition.workspace = true repository.workspace = true @@ -12,16 +12,17 @@ description = "Redis Storage for apalis: use Redis for background jobs and messa # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -apalis-core = { path = "../../packages/apalis-core", version = "0.5.5", default-features = false, features = [ +apalis-core = { path = "../../packages/apalis-core", version = "0.6.0", default-features = false, features = [ "sleep", "json", ] } -redis = { version = "0.25.3", default-features = false, features = [ +redis = { version = "0.27", default-features = false, features = [ "script", "aio", "connection-manager", ] } serde = "1" +serde_json = "1" log = "0.4.21" chrono = { version = "0.4.38", default-features = false, features = [ "clock", @@ -30,16 +31,17 @@ chrono = { version = "0.4.38", default-features = false, features = [ async-stream = "0.3.5" futures = "0.3.30" tokio = { version = "1", features = ["rt", "net"], optional = true } -async-std = { version = "1.12.0", optional = true } +async-std = { version = "1.13.0", optional = true } async-trait = "0.1.80" +tower = "0.4" +thiserror = "2.0.0" [dev-dependencies] tokio = { version = "1", features = 
["macros", "rt-multi-thread"] } email-service = { path = "../../examples/email-service" } -apalis = { path = "../../", default-features = false, features = [ - "tokio-comp", "redis" -] } +apalis = { path = "../../", default-features = false } +apalis-core = { path = "../apalis-core", features = ["test-utils"] } [features] default = ["tokio-comp"] diff --git a/packages/apalis-redis/lua/ack_job.lua b/packages/apalis-redis/lua/done_job.lua similarity index 70% rename from packages/apalis-redis/lua/ack_job.lua rename to packages/apalis-redis/lua/done_job.lua index d451bae..62be5a1 100644 --- a/packages/apalis-redis/lua/ack_job.lua +++ b/packages/apalis-redis/lua/done_job.lua @@ -1,17 +1,20 @@ -- KEYS[1]: this consumer's inflight set -- KEYS[2]: the done jobs set +-- KEYS[3]: the job data hash -- ARGV[1]: the job ID -- ARGV[2]: the current time +-- ARGV[3]: the result of the job --- Returns: nil +-- Returns: bool -- Remove the job from this consumer's inflight set local removed = redis.call("srem", KEYS[1], ARGV[1]) - +local ns = "::result" if removed == 1 then -- Push the job on to the done jobs set redis.call("zadd", KEYS[2], ARGV[2], ARGV[1]) + redis.call("hmset", KEYS[3].. 
ns, ARGV[1], ARGV[3] ) return true end diff --git a/packages/apalis-redis/lua/kill_job.lua b/packages/apalis-redis/lua/kill_job.lua index 3bc7dd3..ebe8d68 100644 --- a/packages/apalis-redis/lua/kill_job.lua +++ b/packages/apalis-redis/lua/kill_job.lua @@ -1,24 +1,22 @@ -- KEYS[1]: this consumer's inflight set -- KEYS[2]: the dead jobs set -- KEYS[3]: the job data hash - -- ARGV[1]: the job ID -- ARGV[2]: the current time --- ARGV[3]: the serialized job data - +-- ARGV[3]: the result of the job -- Returns: nil - -- Remove the job from this consumer's inflight set local removed = redis.call("srem", KEYS[1], ARGV[1]) if removed == 1 then - -- Push the job on to the dead jobs set - redis.call("zadd", KEYS[2], ARGV[2], ARGV[1]) + -- Push the job on to the dead jobs set + redis.call("zadd", KEYS[2], ARGV[2], ARGV[1]) - -- Reset the job data - redis.call("hset", KEYS[3], ARGV[1], ARGV[3]) + -- Save the result of the job + local ns = "::result" + redis.call("hmset", KEYS[3] .. ns, ARGV[1], ARGV[3]) - return 1 + return 1 end return 0 diff --git a/packages/apalis-redis/lua/retry_job.lua b/packages/apalis-redis/lua/retry_job.lua index 6d13b8d..e3df0e7 100644 --- a/packages/apalis-redis/lua/retry_job.lua +++ b/packages/apalis-redis/lua/retry_job.lua @@ -4,7 +4,7 @@ -- ARGV[1]: the job ID -- ARGV[2]: the time at which to retry --- ARGV[3]: the serialized job data +-- ARGV[3]: the result of the job -- Returns: nil @@ -15,8 +15,15 @@ if removed == 1 then -- Push the job on to the scheduled set redis.call("zadd", KEYS[2], ARGV[2], ARGV[1]) + local job = redis.call('HGET', KEYS[3], ARGV[1]) + -- Reset the job data - redis.call("hset", KEYS[3], ARGV[1], ARGV[3]) + redis.call("hset", KEYS[3], ARGV[1], job) + + -- Save the result of the job + local ns = "::result" + redis.call("hmset", KEYS[3].. 
ns, ARGV[1], ARGV[4] ) + end return removed diff --git a/packages/apalis-redis/src/expose.rs b/packages/apalis-redis/src/expose.rs new file mode 100644 index 0000000..4e4e426 --- /dev/null +++ b/packages/apalis-redis/src/expose.rs @@ -0,0 +1,258 @@ +use crate::RedisContext; +use crate::RedisStorage; +use apalis_core::backend::BackendExpose; +use apalis_core::backend::Stat; +use apalis_core::backend::WorkerState; +use apalis_core::codec::json::JsonCodec; +use apalis_core::codec::Codec; +use apalis_core::request::Request; +use apalis_core::request::State; +use apalis_core::worker::Worker; +use apalis_core::worker::WorkerId; +use redis::{ErrorKind, Value}; +use serde::{de::DeserializeOwned, Serialize}; + +type RedisCodec = JsonCodec>; + +impl BackendExpose for RedisStorage +where + T: 'static + Serialize + DeserializeOwned + Send + Unpin + Sync, +{ + type Request = Request; + type Error = redis::RedisError; + async fn stats(&self) -> Result { + let mut conn = self.get_connection().clone(); + let queue = self.get_config(); + let script = r#" + local pending_jobs_set = KEYS[1] + local running_jobs_set = KEYS[2] + local dead_jobs_set = KEYS[3] + local failed_jobs_set = KEYS[4] + local success_jobs_set = KEYS[5] + + local pending_count = redis.call('ZCARD', pending_jobs_set) + local running_count = redis.call('ZCARD', running_jobs_set) + local dead_count = redis.call('ZCARD', dead_jobs_set) + local failed_count = redis.call('ZCARD', failed_jobs_set) + local success_count = redis.call('ZCARD', success_jobs_set) + + return {pending_count, running_count, dead_count, failed_count, success_count} + "#; + + let keys = vec![ + queue.inflight_jobs_set().to_string(), + queue.active_jobs_list().to_string(), + queue.dead_jobs_set().to_string(), + queue.failed_jobs_set().to_string(), + queue.done_jobs_set().to_string(), + ]; + + let results: Vec = redis::cmd("EVAL") + .arg(script) + .arg(keys.len().to_string()) + .arg(keys) + .query_async(&mut conn) + .await?; + + Ok(Stat { + 
pending: results[0], + running: results[1], + dead: results[2], + failed: results[3], + success: results[4], + }) + } + async fn list_jobs( + &self, + status: &State, + page: i32, + ) -> Result, redis::RedisError> { + let mut conn = self.get_connection().clone(); + let queue = self.get_config(); + match status { + State::Pending | State::Scheduled => { + let active_jobs_list = &queue.active_jobs_list(); + let job_data_hash = &queue.job_data_hash(); + let ids: Vec = redis::cmd("LRANGE") + .arg(active_jobs_list) + .arg(((page - 1) * 10).to_string()) + .arg((page * 10).to_string()) + .query_async(&mut conn) + .await?; + + if ids.is_empty() { + return Ok(Vec::new()); + } + let data: Option = redis::cmd("HMGET") + .arg(job_data_hash) + .arg(&ids) + .query_async(&mut conn) + .await?; + + let jobs: Vec> = + deserialize_multiple_jobs::<_, RedisCodec>(data.as_ref()).unwrap(); + Ok(jobs) + } + State::Running => { + let consumers_set = &queue.consumers_set(); + let job_data_hash = &queue.job_data_hash(); + let workers: Vec = redis::cmd("ZRANGE") + .arg(consumers_set) + .arg("0") + .arg("-1") + .query_async(&mut conn) + .await?; + + if workers.is_empty() { + return Ok(Vec::new()); + } + let mut all_jobs = Vec::new(); + for worker in workers { + let ids: Vec = redis::cmd("SMEMBERS") + .arg(&worker) + .query_async(&mut conn) + .await?; + + if ids.is_empty() { + continue; + }; + let data: Option = redis::cmd("HMGET") + .arg(job_data_hash.clone()) + .arg(&ids) + .query_async(&mut conn) + .await?; + + let jobs: Vec> = + deserialize_multiple_jobs::<_, RedisCodec>(data.as_ref()).unwrap(); + all_jobs.extend(jobs); + } + + Ok(all_jobs) + } + State::Done => { + let done_jobs_set = &queue.done_jobs_set(); + let job_data_hash = &queue.job_data_hash(); + let ids: Vec = redis::cmd("ZRANGE") + .arg(done_jobs_set) + .arg(((page - 1) * 10).to_string()) + .arg((page * 10).to_string()) + .query_async(&mut conn) + .await?; + + if ids.is_empty() { + return Ok(Vec::new()); + } + let data: Option = 
redis::cmd("HMGET") + .arg(job_data_hash) + .arg(&ids) + .query_async(&mut conn) + .await?; + + let jobs: Vec> = + deserialize_multiple_jobs::<_, RedisCodec>(data.as_ref()).unwrap(); + Ok(jobs) + } + // State::Retry => Ok(Vec::new()), + State::Failed => { + let failed_jobs_set = &queue.failed_jobs_set(); + let job_data_hash = &queue.job_data_hash(); + let ids: Vec = redis::cmd("ZRANGE") + .arg(failed_jobs_set) + .arg(((page - 1) * 10).to_string()) + .arg((page * 10).to_string()) + .query_async(&mut conn) + .await?; + if ids.is_empty() { + return Ok(Vec::new()); + } + let data: Option = redis::cmd("HMGET") + .arg(job_data_hash) + .arg(&ids) + .query_async(&mut conn) + .await?; + let jobs: Vec> = + deserialize_multiple_jobs::<_, RedisCodec>(data.as_ref()).unwrap(); + + Ok(jobs) + } + State::Killed => { + let dead_jobs_set = &queue.dead_jobs_set(); + let job_data_hash = &queue.job_data_hash(); + let ids: Vec = redis::cmd("ZRANGE") + .arg(dead_jobs_set) + .arg(((page - 1) * 10).to_string()) + .arg((page * 10).to_string()) + .query_async(&mut conn) + .await?; + + if ids.is_empty() { + return Ok(Vec::new()); + } + let data: Option = redis::cmd("HMGET") + .arg(job_data_hash) + .arg(&ids) + .query_async(&mut conn) + .await?; + + let jobs: Vec> = + deserialize_multiple_jobs::<_, RedisCodec>(data.as_ref()).unwrap(); + + Ok(jobs) + } + } + } + async fn list_workers(&self) -> Result>, redis::RedisError> { + let queue = self.get_config(); + let consumers_set = &queue.consumers_set(); + let mut conn = self.get_connection().clone(); + let workers: Vec = redis::cmd("ZRANGE") + .arg(consumers_set) + .arg("0") + .arg("-1") + .query_async(&mut conn) + .await?; + Ok(workers + .into_iter() + .map(|w| { + Worker::new( + WorkerId::new(w.replace(&format!("{}:", &queue.inflight_jobs_set()), "")), + WorkerState::new::(queue.get_namespace().to_owned()), + ) + }) + .collect()) + } +} + +fn deserialize_multiple_jobs>>( + jobs: Option<&Value>, +) -> Option>> +where + T: DeserializeOwned, +{ + 
let jobs = match jobs { + None => None, + Some(Value::Array(val)) => Some(val), + _ => { + // error!( + // "Decoding Message Failed: {:?}", + // "unknown result type for next message" + // ); + None + } + }; + + jobs.map(|values| { + values + .iter() + .filter_map(|v| match v { + Value::BulkString(data) => { + let inner = C::decode(data.to_vec()) + .map_err(|e| (ErrorKind::IoError, "Decode error", e.into().to_string())) + .unwrap(); + Some(inner) + } + _ => None, + }) + .collect() + }) +} diff --git a/packages/apalis-redis/src/lib.rs b/packages/apalis-redis/src/lib.rs index 828beb2..fb91bf2 100644 --- a/packages/apalis-redis/src/lib.rs +++ b/packages/apalis-redis/src/lib.rs @@ -8,17 +8,17 @@ //! apalis storage using Redis as a backend //! ```rust,no_run //! use apalis::prelude::*; -//! use apalis::redis::RedisStorage; +//! use apalis_redis::{RedisStorage, Config}; //! use email_service::send_email; //! //! #[tokio::main] //! async fn main() { -//! let conn = apalis::redis::connect("redis://127.0.0.1/").await.unwrap(); +//! let conn = apalis_redis::connect("redis://127.0.0.1/").await.unwrap(); //! let storage = RedisStorage::new(conn); -//! Monitor::::new() +//! Monitor::new() //! .register( //! WorkerBuilder::new("tasty-pear") -//! .source(storage.clone()) +//! .backend(storage.clone()) //! .build_fn(send_email), //! ) //! .run() @@ -27,7 +27,11 @@ //! } //! 
``` +mod expose; mod storage; pub use storage::connect; pub use storage::Config; +pub use storage::RedisContext; +pub use storage::RedisPollError; +pub use storage::RedisQueueInfo; pub use storage::RedisStorage; diff --git a/packages/apalis-redis/src/storage.rs b/packages/apalis-redis/src/storage.rs index b9a9c01..e335c21 100644 --- a/packages/apalis-redis/src/storage.rs +++ b/packages/apalis-redis/src/storage.rs @@ -1,26 +1,30 @@ use apalis_core::codec::json::JsonCodec; -use apalis_core::data::Extensions; use apalis_core::error::Error; -use apalis_core::layers::{Ack, AckLayer}; +use apalis_core::layers::{Ack, AckLayer, Service}; use apalis_core::poller::controller::Controller; use apalis_core::poller::stream::BackendStream; use apalis_core::poller::Poller; -use apalis_core::request::{Request, RequestStream}; -use apalis_core::storage::{Job, Storage}; -use apalis_core::task::attempt::Attempt; +use apalis_core::request::{Parts, Request, RequestStream}; +use apalis_core::response::Response; +use apalis_core::service_fn::FromRequest; +use apalis_core::storage::Storage; +use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; -use async_stream::try_stream; -use chrono::Utc; -use futures::{FutureExt, TryFutureExt, TryStreamExt}; +use apalis_core::worker::{Event, Worker, WorkerId}; +use apalis_core::{backend::Backend, codec::Codec}; +use chrono::{DateTime, Utc}; +use futures::channel::mpsc::{self, SendError, Sender}; +use futures::{select, FutureExt, SinkExt, StreamExt, TryFutureExt}; use log::*; +use redis::aio::ConnectionLike; use redis::ErrorKind; use redis::{aio::ConnectionManager, Client, IntoConnectionInfo, RedisError, Script, Value}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::fmt; +use std::any::type_name; +use std::fmt::{self, Debug}; +use std::io; use std::num::TryFromIntError; -use std::sync::Arc; +use std::time::SystemTime; use 
std::{marker::PhantomData, time::Duration}; /// Shorthand to create a client and connect @@ -40,22 +44,42 @@ const JOB_DATA_HASH: &str = "{queue}:data"; const SCHEDULED_JOBS_SET: &str = "{queue}:scheduled"; const SIGNAL_LIST: &str = "{queue}:signal"; +/// Represents redis key names for various components of the RedisStorage. +/// +/// This struct defines keys used in Redis to manage jobs and their lifecycle in the storage. #[derive(Clone, Debug)] -struct RedisQueueInfo { - active_jobs_list: String, - consumers_set: String, - dead_jobs_set: String, - done_jobs_set: String, - failed_jobs_set: String, - inflight_jobs_set: String, - job_data_hash: String, - scheduled_jobs_set: String, - signal_list: String, +pub struct RedisQueueInfo { + /// Key for the list of currently active jobs. + pub active_jobs_list: String, + + /// Key for the set of active consumers. + pub consumers_set: String, + + /// Key for the set of jobs that are no longer retryable. + pub dead_jobs_set: String, + + /// Key for the set of jobs that have completed successfully. + pub done_jobs_set: String, + + /// Key for the set of jobs that have failed. + pub failed_jobs_set: String, + + /// Key for the set of jobs that are currently being processed. + pub inflight_jobs_set: String, + + /// Key for the hash storing data for each job. + pub job_data_hash: String, + + /// Key for the set of jobs scheduled for future execution. + pub scheduled_jobs_set: String, + + /// Key for the list used for signaling and communication between consumers and producers. 
+ pub signal_list: String, } #[derive(Clone, Debug)] struct RedisScript { - ack_job: Script, + done_job: Script, enqueue_scheduled: Script, get_jobs: Script, kill_job: Script, @@ -68,72 +92,78 @@ struct RedisScript { vacuum: Script, } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct RedisJob { - ctx: Context, - job: J, +/// The context for a redis storage job +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +pub struct RedisContext { + max_attempts: usize, + lock_by: Option, + run_at: Option, } -impl From> for Request { - fn from(val: RedisJob) -> Self { - let mut data = Extensions::new(); - data.insert(val.ctx.id.clone()); - data.insert(Attempt::new_with_value(val.ctx.attempts)); - data.insert(val.ctx); - Request::new_with_data(val.job, data) +impl FromRequest> for RedisContext { + fn from_request(req: &Request) -> Result { + Ok(req.parts.context.clone()) } } -impl TryFrom> for RedisJob { - type Error = RedisError; - fn try_from(val: Request) -> Result { - let task_id = val - .get::() - .cloned() - .ok_or((ErrorKind::IoError, "Missing TaskId"))?; - let attempts = val.get::().cloned().unwrap_or_default(); - Ok(RedisJob { - job: val.take(), - ctx: Context { - attempts: attempts.current(), - id: task_id, - }, - }) - } -} +/// Errors that can occur while polling a Redis backend. +#[derive(thiserror::Error, Debug)] +pub enum RedisPollError { + /// Error during a keep-alive heartbeat. + #[error("KeepAlive heartbeat encountered an error: `{0}`")] + KeepAliveError(RedisError), -#[derive(Clone, Debug, Serialize, Deserialize)] -struct Context { - id: TaskId, - attempts: usize, + /// Error during enqueueing scheduled tasks. + #[error("EnqueueScheduled heartbeat encountered an error: `{0}`")] + EnqueueScheduledError(RedisError), + + /// Error during polling for the next task or message. + #[error("PollNext heartbeat encountered an error: `{0}`")] + PollNextError(RedisError), + + /// Error during enqueueing tasks for worker consumption. 
+ #[error("Enqueue for worker consumption encountered an error: `{0}`")] + EnqueueError(SendError), + + /// Error during acknowledgment of tasks. + #[error("Ack heartbeat encountered an error: `{0}`")] + AckError(RedisError), + + /// Error during re-enqueuing orphaned tasks. + #[error("ReenqueueOrphaned heartbeat encountered an error: `{0}`")] + ReenqueueOrphanedError(RedisError), } /// Config for a [RedisStorage] #[derive(Clone, Debug)] pub struct Config { - fetch_interval: Duration, + poll_interval: Duration, buffer_size: usize, max_retries: usize, keep_alive: Duration, enqueue_scheduled: Duration, + reenqueue_orphaned_after: Duration, + namespace: String, } impl Default for Config { fn default() -> Self { Self { - fetch_interval: Duration::from_millis(100), + poll_interval: Duration::from_millis(100), buffer_size: 10, max_retries: 5, keep_alive: Duration::from_secs(30), enqueue_scheduled: Duration::from_secs(30), + reenqueue_orphaned_after: Duration::from_secs(300), + namespace: String::from("apalis_redis"), } } } impl Config { - /// Get the rate of polling per unit of time - pub fn get_fetch_interval(&self) -> &Duration { - &self.fetch_interval + /// Get the interval of polling + pub fn get_poll_interval(&self) -> &Duration { + &self.poll_interval } /// Get the number of jobs to fetch @@ -156,101 +186,209 @@ impl Config { &self.enqueue_scheduled } - /// get the fetch interval - pub fn set_fetch_interval(&mut self, fetch_interval: Duration) { - self.fetch_interval = fetch_interval; + /// get the namespace + pub fn get_namespace(&self) -> &String { + &self.namespace + } + + /// get the poll interval + pub fn set_poll_interval(mut self, poll_interval: Duration) -> Self { + self.poll_interval = poll_interval; + self } /// set the buffer setting - pub fn set_buffer_size(&mut self, buffer_size: usize) { + pub fn set_buffer_size(mut self, buffer_size: usize) -> Self { self.buffer_size = buffer_size; + self } /// set the max-retries setting - pub fn 
set_max_retries(&mut self, max_retries: usize) { + pub fn set_max_retries(mut self, max_retries: usize) -> Self { self.max_retries = max_retries; + self } /// set the keep-alive setting - pub fn set_keep_alive(&mut self, keep_alive: Duration) { + pub fn set_keep_alive(mut self, keep_alive: Duration) -> Self { self.keep_alive = keep_alive; + self } /// get the enqueued setting - pub fn set_enqueue_scheduled(&mut self, enqueue_scheduled: Duration) { + pub fn set_enqueue_scheduled(mut self, enqueue_scheduled: Duration) -> Self { self.enqueue_scheduled = enqueue_scheduled; + self + } + + /// set the namespace for the Storage + pub fn set_namespace(mut self, namespace: &str) -> Self { + self.namespace = namespace.to_string(); + self + } + + /// Returns the Redis key for the list of active jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the active jobs list. + pub fn active_jobs_list(&self) -> String { + ACTIVE_JOBS_LIST.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the set of consumers associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the consumers set. + pub fn consumers_set(&self) -> String { + CONSUMERS_SET.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the set of dead jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the dead jobs set. + pub fn dead_jobs_set(&self) -> String { + DEAD_JOBS_SET.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the set of done jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. 
+ /// + /// # Returns + /// A `String` representing the Redis key for the done jobs set. + pub fn done_jobs_set(&self) -> String { + DONE_JOBS_SET.replace("{queue}", &self.namespace) } -} -type InnerCodec = Arc< - Box, Vec, Error = apalis_core::error::Error> + Sync + Send + 'static>, ->; + /// Returns the Redis key for the set of failed jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the failed jobs set. + pub fn failed_jobs_set(&self) -> String { + FAILED_JOBS_SET.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the set of inflight jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the inflight jobs set. + pub fn inflight_jobs_set(&self) -> String { + INFLIGHT_JOB_SET.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the hash storing job data associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the job data hash. + pub fn job_data_hash(&self) -> String { + JOB_DATA_HASH.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the set of scheduled jobs associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the scheduled jobs set. + pub fn scheduled_jobs_set(&self) -> String { + SCHEDULED_JOBS_SET.replace("{queue}", &self.namespace) + } + + /// Returns the Redis key for the list of signals associated with the queue. + /// The key is dynamically generated using the namespace of the queue. + /// + /// # Returns + /// A `String` representing the Redis key for the signal list. 
+ pub fn signal_list(&self) -> String { + SIGNAL_LIST.replace("{queue}", &self.namespace) + } + + /// Gets the reenqueue_orphaned_after duration. + pub fn reenqueue_orphaned_after(&self) -> Duration { + self.reenqueue_orphaned_after + } + + /// Gets a mutable reference to the reenqueue_orphaned_after. + pub fn reenqueue_orphaned_after_mut(&mut self) -> &mut Duration { + &mut self.reenqueue_orphaned_after + } + + /// Occasionally some workers die, or abandon jobs because of panics. + /// This is the time a task takes before its back to the queue + /// + /// Defaults to 5 minutes + pub fn set_reenqueue_orphaned_after(mut self, after: Duration) -> Self { + self.reenqueue_orphaned_after = after; + self + } +} /// Represents a [Storage] that uses Redis for storage. -pub struct RedisStorage { - conn: ConnectionManager, +pub struct RedisStorage>> { + conn: Conn, job_type: PhantomData, - queue: RedisQueueInfo, scripts: RedisScript, controller: Controller, config: Config, - codec: InnerCodec, + codec: PhantomData, } -impl fmt::Debug for RedisStorage { +impl fmt::Debug for RedisStorage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RedisStorage") .field("conn", &"ConnectionManager") .field("job_type", &std::any::type_name::()) - .field("queue", &self.queue) .field("scripts", &self.scripts) .field("config", &self.config) .finish() } } -impl Clone for RedisStorage { +impl Clone for RedisStorage { fn clone(&self) -> Self { Self { conn: self.conn.clone(), job_type: PhantomData, - queue: self.queue.clone(), scripts: self.scripts.clone(), controller: self.controller.clone(), config: self.config.clone(), - codec: self.codec.clone(), + codec: self.codec, } } } -impl RedisStorage { +impl RedisStorage { /// Start a new connection - pub fn new(conn: ConnectionManager) -> Self { - Self::new_with_config(conn, Config::default()) + pub fn new(conn: Conn) -> Self { + Self::new_with_codec::>>( + conn, + Config::default().set_namespace(type_name::()), + ) + } + + 
/// Start a connection with a custom config + pub fn new_with_config(conn: Conn, config: Config) -> Self { + Self::new_with_codec::>>(conn, config) } - /// Start a new connection providing custom config - pub fn new_with_config(conn: ConnectionManager, config: Config) -> Self { - let name = T::NAME; + /// Start a new connection providing custom config and a codec + pub fn new_with_codec(conn: Conn, config: Config) -> RedisStorage + where + C: Codec + Sync + Send + 'static, + { RedisStorage { conn, job_type: PhantomData, controller: Controller::new(), config, - codec: Arc::new(Box::new(JsonCodec)), - queue: RedisQueueInfo { - active_jobs_list: ACTIVE_JOBS_LIST.replace("{queue}", name), - consumers_set: CONSUMERS_SET.replace("{queue}", name), - dead_jobs_set: DEAD_JOBS_SET.replace("{queue}", name), - done_jobs_set: DONE_JOBS_SET.replace("{queue}", name), - failed_jobs_set: FAILED_JOBS_SET.replace("{queue}", name), - inflight_jobs_set: INFLIGHT_JOB_SET.replace("{queue}", name), - job_data_hash: JOB_DATA_HASH.replace("{queue}", name), - scheduled_jobs_set: SCHEDULED_JOBS_SET.replace("{queue}", name), - signal_list: SIGNAL_LIST.replace("{queue}", name), - }, + codec: PhantomData::, scripts: RedisScript { - ack_job: redis::Script::new(include_str!("../lua/ack_job.lua")), + done_job: redis::Script::new(include_str!("../lua/done_job.lua")), push_job: redis::Script::new(include_str!("../lua/push_job.lua")), retry_job: redis::Script::new(include_str!("../lua/retry_job.lua")), enqueue_scheduled: redis::Script::new(include_str!( @@ -272,155 +410,255 @@ impl RedisStorage { } /// Get current connection - pub fn get_connection(&self) -> ConnectionManager { - self.conn.clone() + pub fn get_connection(&self) -> &Conn { + &self.conn + } + + /// Get the config used by the storage + pub fn get_config(&self) -> &Config { + &self.config } } -impl Backend> - for RedisStorage -{ - type Stream = BackendStream>>; +impl RedisStorage { + /// Get the underlying codec details + pub fn 
get_codec(&self) -> &PhantomData { + &self.codec + } +} - type Layer = AckLayer, T>; +impl Backend, Res> for RedisStorage +where + T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, + Conn: ConnectionLike + Send + Sync + 'static, + Res: Send + Serialize + Sync + 'static, + C: Codec> + Send + 'static, +{ + type Stream = BackendStream>>; - fn common_layer(&self, worker_id: WorkerId) -> Self::Layer { - AckLayer::new(self.clone(), worker_id) - } + type Layer = AckLayer)>, T, RedisContext, Res>; - fn poll(self, worker: WorkerId) -> Poller { - let mut storage = self.clone(); + fn poll>>( + mut self, + worker: &Worker, + ) -> Poller { + let (mut tx, rx) = mpsc::channel(self.config.buffer_size); + let (ack, ack_rx) = mpsc::channel(self.config.buffer_size); + let layer = AckLayer::new(ack); let controller = self.controller.clone(); let config = self.config.clone(); - let stream: RequestStream> = Box::pin( - self.stream_jobs(&worker, config.fetch_interval, config.buffer_size) - .map_err(|e| Error::SourceError(e.into())), - ); + let stream: RequestStream> = Box::pin(rx); + let worker = worker.clone(); + let heartbeat = async move { + let mut reenqueue_orphaned_stm = + apalis_core::interval::interval(config.poll_interval).fuse(); - let keep_alive = async move { - loop { - if let Err(e) = storage.keep_alive(&worker).await { - error!("Could not call keep_alive for Worker [{worker}]: {e}") - } - apalis_core::sleep(config.keep_alive).await; + let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse(); + + let mut enqueue_scheduled_stm = + apalis_core::interval::interval(config.enqueue_scheduled).fuse(); + + let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse(); + + let mut ack_stream = ack_rx.fuse(); + + if let Err(e) = self.keep_alive(worker.id()).await { + worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e)))); } - } - .boxed(); - let mut storage = self.clone(); - let enqueue_scheduled = async 
move { + loop { - if let Err(e) = storage.enqueue_scheduled(config.buffer_size).await { - error!("Could not call enqueue_scheduled: {e}") - } - apalis_core::sleep(config.enqueue_scheduled).await; + select! { + _ = keep_alive_stm.next() => { + if let Err(e) = self.keep_alive(worker.id()).await { + worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e)))); + } + } + _ = enqueue_scheduled_stm.next() => { + if let Err(e) = self.enqueue_scheduled(config.buffer_size).await { + worker.emit(Event::Error(Box::new(RedisPollError::EnqueueScheduledError(e)))); + } + } + _ = poll_next_stm.next() => { + let res = self.fetch_next(worker.id()).await; + match res { + Err(e) => { + worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e)))); + } + Ok(res) => { + for job in res { + if let Err(e) = tx.send(Ok(Some(job))).await { + worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e)))); + } + } + } + } + + } + id_to_ack = ack_stream.next() => { + if let Some((ctx, res)) = id_to_ack { + if let Err(e) = self.ack(&ctx, &res).await { + worker.emit(Event::Error(Box::new(RedisPollError::AckError(e)))); + } + } + } + _ = reenqueue_orphaned_stm.next() => { + let dead_since = Utc::now() + - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); + if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await { + worker.emit(Event::Error(Box::new(RedisPollError::ReenqueueOrphanedError(e)))); + } + } + }; } - } - .boxed(); - let heartbeat = async move { - futures::join!(enqueue_scheduled, keep_alive); }; - Poller::new(BackendStream::new(stream, controller), heartbeat.boxed()) + Poller::new_with_layer( + BackendStream::new(stream, controller), + heartbeat.boxed(), + layer, + ) } } -impl Ack for RedisStorage { - type Acknowledger = TaskId; - type Error = RedisError; - async fn ack( - &self, - worker_id: &WorkerId, - task_id: &Self::Acknowledger, - ) -> Result<(), RedisError> { - let mut conn = self.conn.clone(); - let 
ack_job = self.scripts.ack_job.clone(); - let inflight_set = format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let done_jobs_set = &self.queue.done_jobs_set.to_string(); +impl Ack for RedisStorage +where + Res: Serialize + Sync + Send + 'static, + T: Sync + Send, + Conn: ConnectionLike + Send + Sync + 'static, + C: Codec> + Send, +{ + type Context = RedisContext; + type AckError = RedisError; + async fn ack(&mut self, ctx: &Self::Context, res: &Response) -> Result<(), RedisError> { + let inflight_set = format!( + "{}:{}", + self.config.inflight_jobs_set(), + ctx.lock_by.clone().unwrap() + ); let now: i64 = Utc::now().timestamp(); - - ack_job - .key(inflight_set) - .key(done_jobs_set) - .arg(task_id.to_string()) - .arg(now) - .invoke_async(&mut conn) - .await + let task_id = res.task_id.to_string(); + match &res.inner { + Ok(success_res) => { + let done_job = self.scripts.done_job.clone(); + let done_jobs_set = &self.config.done_jobs_set(); + done_job + .key(inflight_set) + .key(done_jobs_set) + .key(self.config.job_data_hash()) + .arg(task_id) + .arg(now) + .arg(C::encode(success_res).map_err(Into::into).unwrap()) + .invoke_async(&mut self.conn) + .await + } + Err(e) => match e { + Error::Abort(e) => { + let kill_job = self.scripts.kill_job.clone(); + let kill_jobs_set = &self.config.dead_jobs_set(); + kill_job + .key(inflight_set) + .key(kill_jobs_set) + .key(self.config.job_data_hash()) + .arg(task_id) + .arg(now) + .arg(e.to_string()) + .invoke_async(&mut self.conn) + .await + } + _ => { + let retry_job = self.scripts.retry_job.clone(); + let retry_jobs_set = &self.config.scheduled_jobs_set(); + retry_job + .key(inflight_set) + .key(retry_jobs_set) + .key(self.config.job_data_hash()) + .arg(task_id) + .arg(now) + .arg(e.to_string()) + .invoke_async(&mut self.conn) + .await + } + }, + } } } -impl RedisStorage { - fn stream_jobs( - &self, +impl RedisStorage +where + T: DeserializeOwned + Send + Unpin + Send + Sync + 'static, + Conn: ConnectionLike + 
Send + Sync + 'static, + C: Codec>, +{ + async fn fetch_next( + &mut self, worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> RequestStream> { - let mut conn = self.conn.clone(); + ) -> Result>, RedisError> { let fetch_jobs = self.scripts.get_jobs.clone(); - let consumers_set = self.queue.consumers_set.to_string(); - let active_jobs_list = self.queue.active_jobs_list.to_string(); - let job_data_hash = self.queue.job_data_hash.to_string(); - let inflight_set = format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let signal_list = self.queue.signal_list.to_string(); - let codec = self.codec.clone(); - Box::pin(try_stream! { - loop { - apalis_core::sleep(interval).await; - let result = fetch_jobs - .key(&consumers_set) - .key(&active_jobs_list) - .key(&inflight_set) - .key(&job_data_hash) - .key(&signal_list) - .arg(buffer_size) // No of jobs to fetch - .arg(&inflight_set) - .invoke_async::<_, Vec>(&mut conn).await; - match result { - Ok(jobs) => { - for job in jobs { - yield deserialize_job(&job).map(|res| codec.decode(res)).transpose()?.map(Into::into) - } - }, - Err(e) => { - warn!("An error occurred during streaming jobs: {e}"); - } - } - + let consumers_set = self.config.consumers_set(); + let active_jobs_list = self.config.active_jobs_list(); + let job_data_hash = self.config.job_data_hash(); + let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id); + let signal_list = self.config.signal_list(); + let namespace = &self.config.namespace; + + let result = fetch_jobs + .key(&consumers_set) + .key(&active_jobs_list) + .key(&inflight_set) + .key(&job_data_hash) + .key(&signal_list) + .arg(self.config.buffer_size) // No of jobs to fetch + .arg(&inflight_set) + .invoke_async::>(&mut self.conn) + .await; + match result { + Ok(jobs) => { + let mut processed = vec![]; + for job in jobs { + let bytes = deserialize_job(&job)?; + let mut request: Request = C::decode(bytes.to_vec()) + .map_err(|e| 
build_error(&e.into().to_string()))?; + request.parts.context.lock_by = Some(worker_id.clone()); + request.parts.namespace = Some(Namespace(namespace.clone())); + processed.push(request) + } + Ok(processed) + } + Err(e) => { + warn!("An error occurred during streaming jobs: {e}"); + Err(e) } - }) + } } } -fn deserialize_job(job: &Value) -> Option<&Vec> { - let job = match job { - job @ Value::Data(_) => Some(job), - Value::Bulk(val) => val.first(), - _ => { - error!( - "Decoding Message Failed: {:?}", - "unknown result type for next message" - ); - None - } - }; +fn build_error(message: &str) -> RedisError { + RedisError::from(io::Error::new(io::ErrorKind::InvalidData, message)) +} +fn deserialize_job(job: &Value) -> Result<&Vec, RedisError> { match job { - Some(Value::Data(v)) => Some(v), - None => None, - _ => { - error!("Decoding Message Failed: {:?}", "Expected Data(&Vec)"); - None - } + Value::BulkString(bytes) => Ok(bytes), + Value::Array(val) | Value::Set(val) => val + .first() + .and_then(|val| { + if let Value::BulkString(bytes) = val { + Some(bytes) + } else { + None + } + }) + .ok_or(build_error("Value::Bulk: Invalid data returned by storage")), + _ => Err(build_error("unknown result type for next message")), } } -impl RedisStorage { +impl RedisStorage { async fn keep_alive(&mut self, worker_id: &WorkerId) -> Result<(), RedisError> { - let mut conn = self.conn.clone(); let register_consumer = self.scripts.register_consumer.clone(); - let inflight_set = format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let consumers_set = self.queue.consumers_set.to_string(); + let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id); + let consumers_set = self.config.consumers_set(); let now: i64 = Utc::now().timestamp(); @@ -428,158 +666,136 @@ impl RedisStorage { .key(consumers_set) .arg(now) .arg(inflight_set) - .invoke_async(&mut conn) + .invoke_async(&mut self.conn) .await } } -impl Storage for RedisStorage +impl Storage for 
RedisStorage where - T: Serialize + DeserializeOwned + Send + 'static + Unpin + Job + Sync, + T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + Conn: ConnectionLike + Send + Sync + 'static, + C: Codec> + Send + 'static, { type Job = T; type Error = RedisError; - type Identifier = TaskId; + type Context = RedisContext; - async fn push(&mut self, job: Self::Job) -> Result { - let mut conn = self.conn.clone(); + async fn push_request( + &mut self, + req: Request, + ) -> Result, RedisError> { + let conn = &mut self.conn; let push_job = self.scripts.push_job.clone(); - let job_data_hash = self.queue.job_data_hash.to_string(); - let active_jobs_list = self.queue.active_jobs_list.to_string(); - let signal_list = self.queue.signal_list.to_string(); - let job_id = TaskId::new(); - let ctx = Context { - attempts: 0, - id: job_id.clone(), - }; - let job = self - .codec - .encode(&RedisJob { ctx, job }) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; + let job_data_hash = self.config.job_data_hash(); + let active_jobs_list = self.config.active_jobs_list(); + let signal_list = self.config.signal_list(); + + let job = C::encode(&req) + .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?; push_job .key(job_data_hash) .key(active_jobs_list) .key(signal_list) - .arg(job_id.to_string()) + .arg(req.parts.task_id.to_string()) .arg(job) - .invoke_async(&mut conn) + .invoke_async(conn) .await?; - Ok(job_id.clone()) + Ok(req.parts) } - async fn schedule(&mut self, job: Self::Job, on: i64) -> Result { - let mut conn = self.conn.clone(); + async fn schedule_request( + &mut self, + req: Request, + on: i64, + ) -> Result, RedisError> { let schedule_job = self.scripts.schedule_job.clone(); - let job_data_hash = self.queue.job_data_hash.to_string(); - let scheduled_jobs_set = self.queue.scheduled_jobs_set.to_string(); - let job_id = TaskId::new(); - let ctx = Context { - attempts: 0, - id: job_id.clone(), - }; - let job = RedisJob { 
job, ctx }; - let job = self - .codec - .encode(&job) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; + let job_data_hash = self.config.job_data_hash(); + let scheduled_jobs_set = self.config.scheduled_jobs_set(); + let job = C::encode(&req) + .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?; schedule_job .key(job_data_hash) .key(scheduled_jobs_set) - .arg(job_id.to_string()) + .arg(req.parts.task_id.to_string()) .arg(job) .arg(on) - .invoke_async(&mut conn) + .invoke_async(&mut self.conn) .await?; - Ok(job_id.clone()) + Ok(req.parts) } - async fn len(&self) -> Result { - let mut conn = self.conn.clone(); + async fn len(&mut self) -> Result { let all_jobs: i64 = redis::cmd("HLEN") - .arg(self.queue.job_data_hash.to_string()) - .query_async(&mut conn) + .arg(self.config.job_data_hash()) + .query_async(&mut self.conn) .await?; let done_jobs: i64 = redis::cmd("ZCOUNT") - .arg(self.queue.done_jobs_set.to_owned()) + .arg(self.config.done_jobs_set()) .arg("-inf") .arg("+inf") - .query_async(&mut conn) + .query_async(&mut self.conn) .await?; Ok(all_jobs - done_jobs) } - async fn fetch_by_id(&self, job_id: &TaskId) -> Result>, RedisError> { - let mut conn = self.conn.clone(); + async fn fetch_by_id( + &mut self, + job_id: &TaskId, + ) -> Result>, RedisError> { let data: Value = redis::cmd("HMGET") - .arg(self.queue.job_data_hash.to_string()) + .arg(self.config.job_data_hash()) .arg(job_id.to_string()) - .query_async(&mut conn) + .query_async(&mut self.conn) .await?; - let job = deserialize_job(&data); - match job { - None => Err(RedisError::from(( - ErrorKind::ResponseError, - "Invalid data returned by storage", - ))), - Some(bytes) => { - let inner = self - .codec - .decode(bytes) - .map_err(|e| (ErrorKind::IoError, "Decode error", e.to_string()))?; - Ok(Some(inner.into())) - } - } + let bytes = deserialize_job(&data)?; + + let inner: Request = C::decode(bytes.to_vec()) + .map_err(|e| (ErrorKind::IoError, "Decode error", 
e.into().to_string()))?; + Ok(Some(inner)) } - async fn update(&self, job: Request) -> Result<(), RedisError> { - let job = job.try_into()?; - let mut conn = self.conn.clone(); - let bytes = self - .codec - .encode(&job) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; + async fn update(&mut self, job: Request) -> Result<(), RedisError> { + let task_id = job.parts.task_id.to_string(); + let bytes = C::encode(&job) + .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?; let _: i64 = redis::cmd("HSET") - .arg(self.queue.job_data_hash.to_string()) - .arg(job.ctx.id.to_string()) + .arg(self.config.job_data_hash()) + .arg(task_id) .arg(bytes) - .query_async(&mut conn) + .query_async(&mut self.conn) .await?; Ok(()) } - async fn reschedule(&mut self, job: Request, wait: Duration) -> Result<(), RedisError> { - let mut conn = self.conn.clone(); + async fn reschedule( + &mut self, + job: Request, + wait: Duration, + ) -> Result<(), RedisError> { let schedule_job = self.scripts.schedule_job.clone(); - let job_id = job - .get::() - .cloned() - .ok_or((ErrorKind::IoError, "Missing TaskId"))?; - let worker_id = job - .get::() - .cloned() - .ok_or((ErrorKind::IoError, "Missing WorkerId"))?; - let job = self - .codec - .encode(&(job.try_into()?)) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; - let job_data_hash = self.queue.job_data_hash.to_string(); - let scheduled_jobs_set = self.queue.scheduled_jobs_set.to_string(); + let job_id = &job.parts.task_id; + let worker_id = &job.parts.context.lock_by.clone().unwrap(); + let job = C::encode(&job) + .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?; + let job_data_hash = self.config.job_data_hash(); + let scheduled_jobs_set = self.config.scheduled_jobs_set(); let on: i64 = Utc::now().timestamp(); let wait: i64 = wait .as_secs() .try_into() .map_err(|e: TryFromIntError| (ErrorKind::IoError, "Duration error", e.to_string()))?; - let inflight_set = 
format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let failed_jobs_set = self.queue.failed_jobs_set.to_string(); + let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id); + let failed_jobs_set = self.config.failed_jobs_set(); redis::cmd("SREM") .arg(inflight_set) .arg(job_id.to_string()) - .query_async(&mut conn) + .query_async(&mut self.conn) .await?; redis::cmd("ZADD") .arg(failed_jobs_set) .arg(on) .arg(job_id.to_string()) - .query_async(&mut conn) + .query_async(&mut self.conn) .await?; schedule_job .key(job_data_hash) @@ -587,58 +803,57 @@ where .arg(job_id.to_string()) .arg(job) .arg(on + wait) - .invoke_async(&mut conn) + .invoke_async(&mut self.conn) .await } - async fn is_empty(&self) -> Result { + async fn is_empty(&mut self) -> Result { self.len().map_ok(|res| res == 0).await } - async fn vacuum(&self) -> Result { + async fn vacuum(&mut self) -> Result { let vacuum_script = self.scripts.vacuum.clone(); - let mut conn = self.conn.clone(); - vacuum_script - .key(self.queue.done_jobs_set.clone()) - .key(self.queue.job_data_hash.clone()) - .invoke_async(&mut conn) + .key(self.config.dead_jobs_set()) + .key(self.config.job_data_hash()) + .invoke_async(&mut self.conn) .await } } -impl RedisStorage { +impl RedisStorage +where + Conn: ConnectionLike + Send + Sync + 'static, + C: Codec> + Send + 'static, +{ /// Attempt to retry a job pub async fn retry(&mut self, worker_id: &WorkerId, task_id: &TaskId) -> Result where - T: Send + DeserializeOwned + Serialize + Job + Unpin + Sync + 'static, + T: Send + DeserializeOwned + Serialize + Unpin + Sync + 'static, { - let mut conn = self.conn.clone(); let retry_job = self.scripts.retry_job.clone(); - let inflight_set = format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let scheduled_jobs_set = self.queue.scheduled_jobs_set.to_string(); - let job_data_hash = self.queue.job_data_hash.to_string(); + let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id); + 
let scheduled_jobs_set = self.config.scheduled_jobs_set(); + let job_data_hash = self.config.job_data_hash(); + let failed_jobs_set = self.config.failed_jobs_set(); let job_fut = self.fetch_by_id(task_id); - let failed_jobs_set = self.queue.failed_jobs_set.to_string(); - let mut storage = self.clone(); let now: i64 = Utc::now().timestamp(); let res = job_fut.await?; + let conn = &mut self.conn; match res { Some(job) => { - let attempt = job.get::().cloned().unwrap_or_default(); + let attempt = &job.parts.attempt; if attempt.current() >= self.config.max_retries { redis::cmd("ZADD") .arg(failed_jobs_set) .arg(now) .arg(task_id.to_string()) - .query_async(&mut conn) + .query_async(conn) .await?; - storage.kill(worker_id, task_id).await?; + self.kill(worker_id, task_id).await?; return Ok(1); } - let job = self - .codec - .encode(&(job.try_into()?)) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; + let job = C::encode(job) + .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?; let res: Result = retry_job .key(inflight_set) @@ -647,7 +862,7 @@ impl RedisStorage { .arg(task_id.to_string()) .arg(now) .arg(job) - .invoke_async(&mut conn) + .invoke_async(conn) .await; match res { Ok(count) => Ok(count), @@ -661,42 +876,30 @@ impl RedisStorage { /// Attempt to kill a job pub async fn kill(&mut self, worker_id: &WorkerId, task_id: &TaskId) -> Result<(), RedisError> where - T: Send + DeserializeOwned + Serialize + Job + Unpin + Sync + 'static, + T: Send + DeserializeOwned + Serialize + Unpin + Sync + 'static, { - let mut conn = self.conn.clone(); let kill_job = self.scripts.kill_job.clone(); - let current_worker_id = format!("{}:{}", self.queue.inflight_jobs_set, worker_id); - let job_data_hash = self.queue.job_data_hash.to_string(); - let dead_jobs_set = self.queue.dead_jobs_set.to_string(); - let fetch_job = self.fetch_by_id(task_id); + let current_worker_id = format!("{}:{}", self.config.inflight_jobs_set(), worker_id); + let 
job_data_hash = self.config.job_data_hash(); + let dead_jobs_set = self.config.dead_jobs_set(); let now: i64 = Utc::now().timestamp(); - let res = fetch_job.await?; - match res { - Some(job) => { - let data = self - .codec - .encode(&job.try_into()?) - .map_err(|e| (ErrorKind::IoError, "Encode error", e.to_string()))?; - kill_job - .key(current_worker_id) - .key(dead_jobs_set) - .key(job_data_hash) - .arg(task_id.to_string()) - .arg(now) - .arg(data) - .invoke_async(&mut conn) - .await - } - None => Err(RedisError::from((ErrorKind::ResponseError, "Id not found"))), - } + kill_job + .key(current_worker_id) + .key(dead_jobs_set) + .key(job_data_hash) + .arg(task_id.to_string()) + .arg(now) + .arg("AbortError") + .invoke_async(&mut self.conn) + .await } /// Required to add scheduled jobs to the active set pub async fn enqueue_scheduled(&mut self, count: usize) -> Result { let enqueue_jobs = self.scripts.enqueue_scheduled.clone(); - let scheduled_jobs_set = self.queue.scheduled_jobs_set.to_string(); - let active_jobs_list = self.queue.active_jobs_list.to_string(); - let signal_list = self.queue.signal_list.to_string(); + let scheduled_jobs_set = self.config.scheduled_jobs_set(); + let active_jobs_list = self.config.active_jobs_list(); + let signal_list = self.config.signal_list(); let now: i64 = Utc::now().timestamp(); let res: Result = enqueue_jobs .key(scheduled_jobs_set) @@ -714,11 +917,10 @@ impl RedisStorage { /// Re-enqueue some jobs that might be abandoned. 
pub async fn reenqueue_active(&mut self, job_ids: Vec<&TaskId>) -> Result<(), RedisError> { - let mut conn = self.conn.clone(); let reenqueue_active = self.scripts.reenqueue_active.clone(); - let inflight_set = self.queue.inflight_jobs_set.to_string(); - let active_jobs_list = self.queue.active_jobs_list.to_string(); - let signal_list = self.queue.signal_list.to_string(); + let inflight_set = self.config.inflight_jobs_set().to_string(); + let active_jobs_list = self.config.active_jobs_list(); + let signal_list = self.config.signal_list(); reenqueue_active .key(inflight_set) @@ -730,19 +932,23 @@ impl RedisStorage { .map(|j| j.to_string()) .collect::>(), ) - .invoke_async(&mut conn) + .invoke_async(&mut self.conn) .await } - /// Re-enqueue some jobs that might be orphaned. + /// Re-enqueue some jobs that might be orphaned after a number of seconds pub async fn reenqueue_orphaned( &mut self, - count: usize, - dead_since: i64, + count: i32, + dead_since: DateTime, ) -> Result { let reenqueue_orphaned = self.scripts.reenqueue_orphaned.clone(); - let consumers_set = self.queue.consumers_set.to_string(); - let active_jobs_list = self.queue.active_jobs_list.to_string(); - let signal_list = self.queue.signal_list.to_string(); + let consumers_set = self.config.consumers_set(); + let active_jobs_list = self.config.active_jobs_list(); + let signal_list = self.config.signal_list(); + + let now = Utc::now(); + let duration = now.signed_duration_since(dead_since); + let dead_since = duration.num_seconds(); let res: Result = reenqueue_orphaned .key(consumers_set) @@ -761,34 +967,38 @@ impl RedisStorage { #[cfg(test)] mod tests { + use apalis_core::generic_storage_test; use email_service::Email; - use futures::StreamExt; + + use apalis_core::test_utils::apalis_test_service_fn; + use apalis_core::test_utils::TestWrapper; + + generic_storage_test!(setup); use super::*; /// migrate DB and return a storage instance. 
- async fn setup() -> RedisStorage { + async fn setup() -> RedisStorage { let redis_url = std::env::var("REDIS_URL").expect("No REDIS_URL is specified"); // Because connections cannot be shared across async runtime // (different runtimes are created for each test), // we don't share the storage and tests must be run sequentially. let conn = connect(redis_url).await.unwrap(); - let storage = RedisStorage::new(conn); + let mut storage = RedisStorage::new(conn); + cleanup(&mut storage, &WorkerId::new("test-worker")).await; storage } /// rollback DB changes made by tests. /// /// You should execute this function in the end of a test - async fn cleanup(mut storage: RedisStorage, _worker_id: &WorkerId) { + async fn cleanup(storage: &mut RedisStorage, _worker_id: &WorkerId) { let _resp: String = redis::cmd("FLUSHDB") .query_async(&mut storage.conn) .await .expect("failed to Flushdb"); } - struct DummyService {} - fn example_email() -> Email { Email { subject: "Test Subject".to_string(), @@ -797,14 +1007,17 @@ mod tests { } } - async fn consume_one(storage: &RedisStorage, worker_id: &WorkerId) -> Request { - let mut stream = storage.stream_jobs(worker_id, std::time::Duration::from_secs(10), 1); + async fn consume_one( + storage: &mut RedisStorage, + worker_id: &WorkerId, + ) -> Request { + let stream = storage.fetch_next(worker_id); stream - .next() .await .expect("stream is empty") + .first() .expect("failed to poll job") - .expect("no job is pending") + .clone() } async fn register_worker_at(storage: &mut RedisStorage) -> WorkerId { @@ -825,7 +1038,10 @@ mod tests { storage.push(email).await.expect("failed to push a job"); } - async fn get_job(storage: &mut RedisStorage, job_id: &TaskId) -> Request { + async fn get_job( + storage: &mut RedisStorage, + job_id: &TaskId, + ) -> Request { storage .fetch_by_id(job_id) .await @@ -841,8 +1057,6 @@ mod tests { let worker_id = register_worker(&mut storage).await; let _job = consume_one(&mut storage, &worker_id).await; - - 
cleanup(storage, &worker_id).await; } #[tokio::test] @@ -853,15 +1067,17 @@ mod tests { let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let job_id = &job.get::().unwrap().id; - + let ctx = &job.parts.context; + let res = 42usize; storage - .ack(&worker_id, &job_id) + .ack( + ctx, + &Response::success(res, job.parts.task_id.clone(), job.parts.attempt.clone()), + ) .await .expect("failed to acknowledge the job"); - let _job = get_job(&mut storage, &job_id).await; - cleanup(storage, &worker_id).await; + let _job = get_job(&mut storage, &job.parts.task_id).await; } #[tokio::test] @@ -873,7 +1089,7 @@ mod tests { let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let job_id = &job.get::().unwrap().id; + let job_id = &job.parts.task_id; storage .kill(&worker_id, &job_id) @@ -881,8 +1097,6 @@ mod tests { .expect("failed to kill job"); let _job = get_job(&mut storage, &job_id).await; - - cleanup(storage, &worker_id).await; } #[tokio::test] @@ -893,12 +1107,21 @@ mod tests { let worker_id = register_worker_at(&mut storage).await; - let _job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker_id).await; + let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(300)).unwrap(); storage - .reenqueue_orphaned(5, 300) + .reenqueue_orphaned(1, dead_since) .await .expect("failed to reenqueue_orphaned"); - cleanup(storage, &worker_id).await; + let job = get_job(&mut storage, &job.parts.task_id).await; + let ctx = &job.parts.context; + // assert_eq!(*ctx.status(), State::Pending); + // assert!(ctx.done_at().is_none()); + assert!(ctx.lock_by.is_none()); + // assert!(ctx.lock_at().is_none()); + // assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned())); + // TODO: Redis should store context aside + // assert_eq!(job.parts.attempt.current(), 1); } #[tokio::test] @@ -909,12 +1132,19 @@ mod tests 
{ let worker_id = register_worker_at(&mut storage).await; - let _job = consume_one(&mut storage, &worker_id).await; - let result = storage - .reenqueue_orphaned(5, 300) + let job = consume_one(&mut storage, &worker_id).await; + let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(300)).unwrap(); + storage + .reenqueue_orphaned(1, dead_since) .await .expect("failed to reenqueue_orphaned"); - - cleanup(storage, &worker_id).await; + let job = get_job(&mut storage, &job.parts.task_id).await; + let _ctx = &job.parts.context; + // assert_eq!(*ctx.status(), State::Running); + // TODO: update redis context + // assert_eq!(ctx.lock_by, Some(worker_id)); + // assert!(ctx.lock_at().is_some()); + // assert_eq!(*ctx.last_error(), None); + assert_eq!(job.parts.attempt.current(), 0); } } diff --git a/packages/apalis-sql/Cargo.toml b/packages/apalis-sql/Cargo.toml index a018d80..6b78766 100644 --- a/packages/apalis-sql/Cargo.toml +++ b/packages/apalis-sql/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "apalis-sql" -version = "0.5.5" +version = "0.6.0" authors = ["Njuguna Mureithi "] edition.workspace = true repository.workspace = true @@ -10,7 +10,7 @@ license = "MIT" description = "SQL Storage for apalis. 
Use sqlite, postgres and mysql for background job processing" [features] -default = ["sqlite", "migrate", "postgres"] +default = ["migrate"] postgres = ["sqlx/postgres", "sqlx/json"] sqlite = ["sqlx/sqlite", "sqlx/json"] mysql = ["sqlx/mysql", "sqlx/json", "sqlx/bigdecimal"] @@ -26,7 +26,7 @@ features = ["chrono"] [dependencies] serde = { version = "1", features = ["derive"] } serde_json = "1" -apalis-core = { path = "../../packages/apalis-core", version = "0.5.5", default-features = false, features = [ +apalis-core = { path = "../../packages/apalis-core", version = "0.6.0", default-features = false, features = [ "sleep", "json", ] } @@ -35,15 +35,18 @@ futures = "0.3.30" async-stream = "0.3.5" tokio = { version = "1", features = ["rt", "net"], optional = true } futures-lite = "2.3.0" -async-std = { version = "1.12.0", optional = true } +async-std = { version = "1.13.0", optional = true } +chrono = { version = "0.4", features = ["serde"] } +thiserror = "2.0.0" + [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } email-service = { path = "../../examples/email-service" } -apalis = { path = "../../", default-features = false, features = [ - "tokio-comp", -] } +apalis = { path = "../../", default-features = false } once_cell = "1.19.0" +apalis-sql = { path = ".", features = ["tokio-comp"] } +apalis-core = { path = "../apalis-core", features = ["test-utils"] } [package.metadata.docs.rs] # defines the configuration attribute `docsrs` diff --git a/packages/apalis-sql/src/context.rs b/packages/apalis-sql/src/context.rs index 50cc0cf..40a37d0 100644 --- a/packages/apalis-sql/src/context.rs +++ b/packages/apalis-sql/src/context.rs @@ -1,18 +1,16 @@ -use apalis_core::error::Error; -use apalis_core::task::{attempt::Attempt, task_id::TaskId}; +use apalis_core::request::Request; +use apalis_core::service_fn::FromRequest; use apalis_core::worker::WorkerId; +use apalis_core::{error::Error, request::State}; +use chrono::{DateTime, Utc}; use 
serde::{Deserialize, Serialize}; -use sqlx::types::chrono::{DateTime, Utc}; -use std::{fmt, str::FromStr}; /// The context for a job is represented here -/// Used to provide a context when a job is defined through the [Job] trait -#[derive(Debug, Clone)] +/// Used to provide a context for a job with an sql backend +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct SqlContext { - id: TaskId, status: State, run_at: DateTime, - attempts: Attempt, max_attempts: i32, last_error: Option, lock_at: Option, @@ -20,16 +18,20 @@ pub struct SqlContext { done_at: Option, } +impl Default for SqlContext { + fn default() -> Self { + Self::new() + } +} + impl SqlContext { - /// Build a new context with defaults given an ID. - pub fn new(id: TaskId) -> Self { + /// Build a new context with defaults + pub fn new() -> Self { SqlContext { - id, status: State::Pending, run_at: Utc::now(), lock_at: None, done_at: None, - attempts: Default::default(), max_attempts: 25, last_error: None, lock_by: None, @@ -46,21 +48,6 @@ impl SqlContext { self.max_attempts } - /// Get the id for a job - pub fn id(&self) -> &TaskId { - &self.id - } - - /// Gets the current attempts for a job. 
Default 0 - pub fn attempts(&self) -> &Attempt { - &self.attempts - } - - /// Set the number of attempts - pub fn set_attempts(&mut self, attempts: i32) { - self.attempts = Attempt::new_with_value(attempts.try_into().unwrap()); - } - /// Get the time a job was done pub fn done_at(&self) -> &Option { &self.done_at @@ -117,65 +104,13 @@ impl SqlContext { } /// Set the last error - pub fn set_last_error(&mut self, error: String) { - self.last_error = Some(error); - } - - /// Record an attempt to execute the request - pub fn record_attempt(&mut self) { - self.attempts.increment(); - } -} - -/// Represents the state of a [Request] -#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, std::cmp::Eq)] -pub enum State { - /// Job is pending - #[serde(alias = "Latest")] - Pending, - /// Job is running - Running, - /// Job was done successfully - Done, - /// Retry Job - Retry, - /// Job has failed. Check `last_error` - Failed, - /// Job has been killed - Killed, -} - -impl Default for State { - fn default() -> Self { - State::Pending - } -} - -impl FromStr for State { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "Pending" | "Latest" => Ok(State::Pending), - "Running" => Ok(State::Running), - "Done" => Ok(State::Done), - "Retry" => Ok(State::Retry), - "Failed" => Ok(State::Failed), - "Killed" => Ok(State::Killed), - _ => Err(Error::InvalidContext("Invalid Job state".to_string())), - } + pub fn set_last_error(&mut self, error: Option) { + self.last_error = error; } } -impl fmt::Display for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self { - State::Pending => write!(f, "Pending"), - State::Running => write!(f, "Running"), - State::Done => write!(f, "Done"), - State::Retry => write!(f, "Retry"), - State::Failed => write!(f, "Failed"), - State::Killed => write!(f, "Killed"), - } +impl FromRequest> for SqlContext { + fn from_request(req: &Request) -> Result { + Ok(req.parts.context.clone()) } } diff --git 
a/packages/apalis-sql/src/from_row.rs b/packages/apalis-sql/src/from_row.rs index b14c675..db5787d 100644 --- a/packages/apalis-sql/src/from_row.rs +++ b/packages/apalis-sql/src/from_row.rs @@ -1,23 +1,24 @@ +use apalis_core::request::Parts; +use apalis_core::task::attempt::Attempt; use apalis_core::task::task_id::TaskId; -use apalis_core::{data::Extensions, request::Request, worker::WorkerId}; +use apalis_core::{request::Request, worker::WorkerId}; + +use serde::{Deserialize, Serialize}; use sqlx::{Decode, Type}; use crate::context::SqlContext; /// Wrapper for [Request] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct SqlRequest { - pub(crate) req: T, - pub(crate) context: SqlContext, + /// The inner request + pub req: Request, + pub(crate) _priv: (), } -impl From> for Request { - fn from(val: SqlRequest) -> Self { - let mut data = Extensions::new(); - data.insert(val.context.id().clone()); - data.insert(val.context.attempts().clone()); - data.insert(val.context); - - Request::new_with_data(val.req, data) +impl SqlRequest { + /// Creates a new SqlRequest. 
+ pub fn new(req: Request) -> Self { + SqlRequest { req, _priv: () } } } @@ -27,26 +28,30 @@ impl<'r, T: Decode<'r, sqlx::Sqlite> + Type> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for SqlRequest { fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result { - use sqlx::types::chrono::DateTime; + use chrono::DateTime; use sqlx::Row; use std::str::FromStr; let job: T = row.try_get("job")?; - let id: TaskId = + let task_id: TaskId = TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { index: "id".to_string(), source: Box::new(e), })?; - let mut context = crate::context::SqlContext::new(id); + let mut parts = Parts::::default(); + parts.task_id = task_id; + + let attempt: i32 = row.try_get("attempts").unwrap_or(0); + parts.attempt = Attempt::new_with_value(attempt as usize); + + let mut context = crate::context::SqlContext::new(); let run_at: i64 = row.try_get("run_at")?; context.set_run_at(DateTime::from_timestamp(run_at, 0).unwrap_or_default()); - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); - - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); + if let Ok(max_attempts) = row.try_get("max_attempts") { + context.set_max_attempts(max_attempts) + } let done_at: Option = row.try_get("done_at").unwrap_or_default(); context.set_done_at(done_at); @@ -74,8 +79,11 @@ impl<'r, T: Decode<'r, sqlx::Sqlite> + Type> source: "Could not parse lock_by as a WorkerId".into(), })?, ); - - Ok(SqlRequest { context, req: job }) + parts.context = context; + Ok(SqlRequest { + req: Request::new_with_parts(job, parts), + _priv: (), + }) } } @@ -85,32 +93,35 @@ impl<'r, T: Decode<'r, sqlx::Postgres> + Type> sqlx::FromRow<'r, sqlx::postgres::PgRow> for SqlRequest { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { + use chrono::Utc; use sqlx::Row; use std::str::FromStr; - type Timestamp = i64; let job: T = row.try_get("job")?; - let id: TaskId = + let task_id: TaskId = 
TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { index: "id".to_string(), source: Box::new(e), })?; - let mut context = SqlContext::new(id); + let mut parts = Parts::::default(); + parts.task_id = task_id; + + let attempt: i32 = row.try_get("attempts").unwrap_or(0); + parts.attempt = Attempt::new_with_value(attempt as usize); + let mut context = SqlContext::new(); let run_at = row.try_get("run_at")?; context.set_run_at(run_at); - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); + if let Ok(max_attempts) = row.try_get("max_attempts") { + context.set_max_attempts(max_attempts) + } - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); - - let done_at: Option = row.try_get("done_at").unwrap_or_default(); - context.set_done_at(done_at); + let done_at: Option> = row.try_get("done_at").unwrap_or_default(); + context.set_done_at(done_at.map(|d| d.timestamp())); - let lock_at: Option = row.try_get("lock_at").unwrap_or_default(); - context.set_lock_at(lock_at); + let lock_at: Option> = row.try_get("lock_at").unwrap_or_default(); + context.set_lock_at(lock_at.map(|d| d.timestamp())); let last_error = row.try_get("last_error").unwrap_or_default(); context.set_last_error(last_error); @@ -132,7 +143,11 @@ impl<'r, T: Decode<'r, sqlx::Postgres> + Type> source: "Could not parse lock_by as a WorkerId".into(), })?, ); - Ok(SqlRequest { context, req: job }) + parts.context = context; + Ok(SqlRequest { + req: Request::new_with_parts(job, parts), + _priv: (), + }) } } @@ -144,31 +159,32 @@ impl<'r, T: Decode<'r, sqlx::MySql> + Type> sqlx::FromRow<'r, sqlx: fn from_row(row: &'r sqlx::mysql::MySqlRow) -> Result { use sqlx::Row; use std::str::FromStr; - - type Timestamp = i64; - let job: T = row.try_get("job")?; - let id: TaskId = + let task_id: TaskId = TaskId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { index: "id".to_string(), source: 
Box::new(e), })?; - let mut context = SqlContext::new(id); + let mut parts = Parts::::default(); + parts.task_id = task_id; + + let attempt: i32 = row.try_get("attempts").unwrap_or(0); + parts.attempt = Attempt::new_with_value(attempt as usize); + + let mut context = SqlContext::new(); let run_at = row.try_get("run_at")?; context.set_run_at(run_at); - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); + if let Ok(max_attempts) = row.try_get("max_attempts") { + context.set_max_attempts(max_attempts) + } - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); + let done_at: Option = row.try_get("done_at").unwrap_or_default(); + context.set_done_at(done_at.map(|d| d.and_utc().timestamp())); - let done_at: Option = row.try_get("done_at").unwrap_or_default(); - context.set_done_at(done_at); - - let lock_at: Option = row.try_get("lock_at").unwrap_or_default(); - context.set_lock_at(lock_at); + let lock_at: Option = row.try_get("lock_at").unwrap_or_default(); + context.set_lock_at(lock_at.map(|d| d.and_utc().timestamp())); let last_error = row.try_get("last_error").unwrap_or_default(); context.set_last_error(last_error); @@ -190,7 +206,10 @@ impl<'r, T: Decode<'r, sqlx::MySql> + Type> sqlx::FromRow<'r, sqlx: source: "Could not parse lock_by as a WorkerId".into(), })?, ); - - Ok(SqlRequest { context, req: job }) + parts.context = context; + Ok(SqlRequest { + req: Request::new_with_parts(job, parts), + _priv: (), + }) } } diff --git a/packages/apalis-sql/src/lib.rs b/packages/apalis-sql/src/lib.rs index 25caf27..4d3a1e0 100644 --- a/packages/apalis-sql/src/lib.rs +++ b/packages/apalis-sql/src/lib.rs @@ -10,7 +10,9 @@ //! apalis offers Sqlite, Mysql and Postgres storages for its workers. //! 
See relevant modules for examples -use std::time::Duration; +use std::{num::TryFromIntError, time::Duration}; + +use apalis_core::{error::Error, request::State}; /// The context of the sql job pub mod context; @@ -33,12 +35,28 @@ pub mod sqlite; #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] pub mod mysql; +// Re-exports +pub use sqlx; + /// Config for sql storages #[derive(Debug, Clone)] pub struct Config { keep_alive: Duration, buffer_size: usize, poll_interval: Duration, + reenqueue_orphaned_after: Duration, + namespace: String, +} + +/// A general sql error +#[derive(Debug, thiserror::Error)] +pub enum SqlError { + /// Handles sqlx errors + #[error("sqlx::Error: {0}")] + Sqlx(#[from] sqlx::Error), + /// Handles int conversion errors + #[error("TryFromIntError: {0}")] + TryFromInt(#[from] TryFromIntError), } impl Default for Config { @@ -46,16 +64,23 @@ impl Default for Config { Self { keep_alive: Duration::from_secs(30), buffer_size: 10, - poll_interval: Duration::from_millis(50), + poll_interval: Duration::from_millis(100), + reenqueue_orphaned_after: Duration::from_secs(300), // 5 minutes + namespace: String::from("apalis::sql"), } } } impl Config { + /// Create a new config with a jobs namespace + pub fn new(namespace: &str) -> Self { + Config::default().set_namespace(namespace) + } + /// Interval between database poll queries /// - /// Defaults to 30ms - pub fn poll_interval(mut self, interval: Duration) -> Self { + /// Defaults to 100ms + pub fn set_poll_interval(mut self, interval: Duration) -> Self { self.poll_interval = interval; self } @@ -63,7 +88,7 @@ impl Config { /// Interval between worker keep-alive database updates /// /// Defaults to 30s - pub fn keep_alive(mut self, keep_alive: Duration) -> Self { + pub fn set_keep_alive(mut self, keep_alive: Duration) -> Self { self.keep_alive = keep_alive; self } @@ -71,8 +96,206 @@ impl Config { /// Buffer size to use when querying for jobs /// /// Defaults to 10 - pub fn buffer_size(mut self, 
buffer_size: usize) -> Self { + pub fn set_buffer_size(mut self, buffer_size: usize) -> Self { self.buffer_size = buffer_size; self } + + /// Set the namespace to consume and push jobs to + /// + /// Defaults to "apalis::sql" + pub fn set_namespace(mut self, namespace: &str) -> Self { + self.namespace = namespace.to_string(); + self + } + + /// Gets a reference to the keep_alive duration. + pub fn keep_alive(&self) -> &Duration { + &self.keep_alive + } + + /// Gets a mutable reference to the keep_alive duration. + pub fn keep_alive_mut(&mut self) -> &mut Duration { + &mut self.keep_alive + } + + /// Gets the buffer size. + pub fn buffer_size(&self) -> usize { + self.buffer_size + } + + /// Gets a reference to the poll_interval duration. + pub fn poll_interval(&self) -> &Duration { + &self.poll_interval + } + + /// Gets a mutable reference to the poll_interval duration. + pub fn poll_interval_mut(&mut self) -> &mut Duration { + &mut self.poll_interval + } + + /// Gets a reference to the namespace. + pub fn namespace(&self) -> &String { + &self.namespace + } + + /// Gets a mutable reference to the namespace. + pub fn namespace_mut(&mut self) -> &mut String { + &mut self.namespace + } + + /// Gets the reenqueue_orphaned_after duration. + pub fn reenqueue_orphaned_after(&self) -> Duration { + self.reenqueue_orphaned_after + } + + /// Gets a mutable reference to the reenqueue_orphaned_after. + pub fn reenqueue_orphaned_after_mut(&mut self) -> &mut Duration { + &mut self.reenqueue_orphaned_after + } + + /// Occasionally some workers die, or abandon jobs because of panics. 
+ /// This is the time a task takes before its back to the queue + /// + /// Defaults to 5 minutes + pub fn set_reenqueue_orphaned_after(mut self, after: Duration) -> Self { + self.reenqueue_orphaned_after = after; + self + } +} + +/// Calculates the status from a result +pub fn calculate_status(res: &Result) -> State { + match res { + Ok(_) => State::Done, + Err(e) => match &e { + Error::Abort(_) => State::Killed, + _ => State::Failed, + }, + } +} + +/// Standard checks for any sql backend +#[macro_export] +macro_rules! sql_storage_tests { + ($setup:path, $storage_type:ty, $job_type:ty) => { + async fn setup_test_wrapper( + ) -> TestWrapper<$storage_type, Request<$job_type, SqlContext>, ()> { + let (mut t, poller) = TestWrapper::new_with_service( + $setup().await, + apalis_core::service_fn::service_fn(email_service::send_email), + ); + tokio::spawn(poller); + t.vacuum().await.unwrap(); + t + } + + #[tokio::test] + async fn integration_test_kill_job() { + let mut storage = setup_test_wrapper().await; + + storage + .push(email_service::example_killed_email()) + .await + .unwrap(); + + let (job_id, res) = storage.execute_next().await; + assert_eq!(res, Err("AbortError: Invalid character.".to_owned())); + apalis_core::sleep(Duration::from_secs(1)).await; + let job = storage + .fetch_by_id(&job_id) + .await + .unwrap() + .expect("No job found"); + let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Killed); + // assert!(ctx.done_at().is_some()); + assert_eq!( + ctx.last_error().clone().unwrap(), + "{\"Err\":\"AbortError: Invalid character.\"}" + ); + } + + #[tokio::test] + async fn integration_test_acknowledge_good_job() { + let mut storage = setup_test_wrapper().await; + storage + .push(email_service::example_good_email()) + .await + .unwrap(); + + let (job_id, res) = storage.execute_next().await; + assert_eq!(res, Ok("()".to_owned())); + apalis_core::sleep(Duration::from_secs(1)).await; + let job = storage.fetch_by_id(&job_id).await.unwrap().unwrap(); + 
let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Done); + assert!(ctx.done_at().is_some()); + } + + #[tokio::test] + async fn integration_test_acknowledge_failed_job() { + let mut storage = setup_test_wrapper().await; + + storage + .push(email_service::example_retry_able_email()) + .await + .unwrap(); + + let (job_id, res) = storage.execute_next().await; + assert_eq!( + res, + Err("FailedError: Missing separator character '@'.".to_owned()) + ); + apalis_core::sleep(Duration::from_secs(1)).await; + let job = storage.fetch_by_id(&job_id).await.unwrap().unwrap(); + let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Failed); + assert!(job.parts.attempt.current() >= 1); + assert_eq!( + ctx.last_error().clone().unwrap(), + "{\"Err\":\"FailedError: Missing separator character '@'.\"}" + ); + } + + #[tokio::test] + async fn worker_consume() { + use apalis_core::builder::WorkerBuilder; + use apalis_core::builder::WorkerFactoryFn; + let storage = $setup().await; + let mut handle = storage.clone(); + + let parts = handle + .push(email_service::example_good_email()) + .await + .unwrap(); + + async fn task(_job: Email) -> &'static str { + tokio::time::sleep(Duration::from_millis(100)).await; + "Job well done" + } + let worker = WorkerBuilder::new("rango-tango").backend(storage); + let worker = worker.build_fn(task); + let wkr = worker.run(); + + let w = wkr.get_handle(); + + let runner = async move { + apalis_core::sleep(Duration::from_secs(3)).await; + let job_id = &parts.task_id; + let job = get_job(&mut handle, job_id).await; + let ctx = job.parts.context; + + assert_eq!(*ctx.status(), State::Done); + assert!(ctx.done_at().is_some()); + assert!(ctx.lock_by().is_some()); + assert!(ctx.lock_at().is_some()); + assert!(ctx.last_error().is_some()); // TODO: rename last_error to last_result + + w.stop(); + }; + + tokio::join!(runner, wkr); + } + }; } diff --git a/packages/apalis-sql/src/mysql.rs b/packages/apalis-sql/src/mysql.rs index 555e628..c925d0d 
100644 --- a/packages/apalis-sql/src/mysql.rs +++ b/packages/apalis-sql/src/mysql.rs @@ -1,62 +1,76 @@ +use apalis_core::backend::{BackendExpose, Stat, WorkerState}; use apalis_core::codec::json::JsonCodec; -use apalis_core::error::Error; +use apalis_core::error::{BoxDynError, Error}; use apalis_core::layers::{Ack, AckLayer}; use apalis_core::notify::Notify; use apalis_core::poller::controller::Controller; use apalis_core::poller::stream::BackendStream; use apalis_core::poller::Poller; -use apalis_core::request::{Request, RequestStream}; -use apalis_core::storage::{Job, Storage}; +use apalis_core::request::{Parts, Request, RequestStream, State}; +use apalis_core::response::Response; +use apalis_core::storage::Storage; +use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; +use apalis_core::{backend::Backend, codec::Codec}; use async_stream::try_stream; +use chrono::{DateTime, Utc}; use futures::{Stream, StreamExt, TryStreamExt}; use log::error; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use sqlx::mysql::MySqlRow; -use sqlx::types::chrono::{DateTime, Utc}; use sqlx::{MySql, Pool, Row}; +use std::any::type_name; use std::convert::TryInto; +use std::fmt::Debug; use std::sync::Arc; use std::{fmt, io}; use std::{marker::PhantomData, ops::Add, time::Duration}; use crate::context::SqlContext; use crate::from_row::SqlRequest; -use crate::Config; +use crate::{calculate_status, Config, SqlError}; pub use sqlx::mysql::MySqlPool; +type MysqlCodec = JsonCodec; + /// Represents a [Storage] that persists to MySQL -pub struct MysqlStorage { +pub struct MysqlStorage> +where + C: Codec, +{ pool: Pool, job_type: PhantomData, controller: Controller, config: Config, - codec: Arc + Sync + Send + 'static>>, - ack_notify: Notify<(WorkerId, TaskId)>, + codec: PhantomData, + ack_notify: 
Notify<(SqlContext, Response)>, } -impl fmt::Debug for MysqlStorage { +impl fmt::Debug for MysqlStorage +where + C: Debug + Codec, + C::Compact: Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MysqlStorage") .field("pool", &self.pool) .field("job_type", &"PhantomData") .field("controller", &self.controller) .field("config", &self.config) - .field( - "codec", - &"Arc + Sync + Send + 'static>>", - ) + .field("codec", &self.codec) .field("ack_notify", &self.ack_notify) .finish() } } -impl Clone for MysqlStorage { +impl Clone for MysqlStorage +where + C: Debug + Codec, +{ fn clone(&self) -> Self { let pool = self.pool.clone(); MysqlStorage { @@ -64,13 +78,13 @@ impl Clone for MysqlStorage { job_type: PhantomData, controller: self.controller.clone(), config: self.config.clone(), - codec: self.codec.clone(), + codec: self.codec, ack_notify: self.ack_notify.clone(), } } } -impl MysqlStorage<()> { +impl MysqlStorage<(), JsonCodec> { /// Get mysql migrations without running them #[cfg(feature = "migrate")] pub fn migrations() -> sqlx::migrate::Migrator { @@ -85,10 +99,13 @@ impl MysqlStorage<()> { } } -impl MysqlStorage { +impl MysqlStorage +where + T: Serialize + DeserializeOwned, +{ /// Create a new instance from a pool pub fn new(pool: MySqlPool) -> Self { - Self::new_with_config(pool, Config::default()) + Self::new_with_config(pool, Config::new(type_name::())) } /// Create a new instance from a pool and custom config @@ -98,8 +115,8 @@ impl MysqlStorage { job_type: PhantomData, controller: Controller::new(), config, - codec: Arc::new(Box::new(JsonCodec)), ack_notify: Notify::new(), + codec: PhantomData, } } @@ -107,26 +124,33 @@ impl MysqlStorage { pub fn pool(&self) -> &Pool { &self.pool } + + /// Get the config used by the storage + pub fn get_config(&self) -> &Config { + &self.config + } } -impl MysqlStorage { +impl MysqlStorage +where + T: DeserializeOwned + Send + Unpin + Sync + 'static, + C: Codec + Send + 'static, +{ fn 
stream_jobs( self, worker_id: &WorkerId, interval: Duration, buffer_size: usize, - ) -> impl Stream>, sqlx::Error>> { + ) -> impl Stream>, sqlx::Error>> { let pool = self.pool.clone(); let worker_id = worker_id.to_string(); - try_stream! { - let pool = pool.clone(); let buffer_size = u32::try_from(buffer_size) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?; loop { apalis_core::sleep(interval).await; let pool = pool.clone(); - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); let mut tx = pool.begin().await?; let fetch_query = "SELECT id FROM jobs WHERE status = 'Pending' AND run_at <= NOW() AND job_type = ? ORDER BY run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED"; @@ -156,10 +180,14 @@ impl MysqlStorage let jobs: Vec> = query.fetch_all(&pool).await?; for job in jobs { - yield Some(Into::into(SqlRequest { - context: job.context, - req: self.codec.decode(&job.req).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))? - })) + yield { + let (req, ctx) = job.req.take_parts(); + let req = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + let mut req: Request = Request::new_with_parts(req, ctx); + req.parts.namespace = Some(Namespace(self.config.namespace.clone())); + Some(req) + } } } } @@ -174,7 +202,7 @@ impl MysqlStorage let pool = self.pool.clone(); let mut tx = pool.acquire().await?; - let worker_type = T::NAME; + let worker_type = self.config.namespace.clone(); let storage_name = std::any::type_name::(); let query = "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) 
ON DUPLICATE KEY UPDATE id = ?;"; @@ -191,62 +219,66 @@ impl MysqlStorage } } -impl Storage for MysqlStorage +impl Storage for MysqlStorage where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + C: Codec + Send, { type Job = T; type Error = sqlx::Error; - type Identifier = TaskId; + type Context = SqlContext; - async fn push(&mut self, job: Self::Job) -> Result { - let id = TaskId::new(); + async fn push_request( + &mut self, + job: Request, + ) -> Result, sqlx::Error> { + let (args, parts) = job.take_parts(); let query = - "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, 25, now(), NULL, NULL, NULL, NULL)"; + "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL)"; let pool = self.pool.clone(); - let job = self - .codec - .encode(&job) + let job = C::encode(args) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) .bind(job) - .bind(id.to_string()) + .bind(parts.task_id.to_string()) .bind(job_type.to_string()) + .bind(parts.context.max_attempts()) .execute(&pool) .await?; - Ok(id) + Ok(parts) } - async fn schedule(&mut self, job: Self::Job, on: i64) -> Result { - let query = - "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, 25, ?, NULL, NULL, NULL, NULL)"; + async fn schedule_request( + &mut self, + req: Request, + on: i64, + ) -> Result, sqlx::Error> { + let query = "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL)"; let pool = self.pool.clone(); - let id = TaskId::new(); - let job = self - .codec - .encode(&job) + let args = C::encode(&req.args) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) - .bind(job) - .bind(id.to_string()) + .bind(args) + 
.bind(req.parts.task_id.to_string()) .bind(job_type) + .bind(req.parts.context.max_attempts()) .bind(on) .execute(&pool) .await?; - Ok(id) + Ok(req.parts) } async fn fetch_by_id( - &self, + &mut self, job_id: &TaskId, - ) -> Result>, sqlx::Error> { + ) -> Result>, sqlx::Error> { let pool = self.pool.clone(); let fetch_query = "SELECT * FROM jobs WHERE id = ?"; @@ -256,19 +288,18 @@ where .await?; match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, parts) = job.req.take_parts(); + let req = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + let mut req = Request::new_with_parts(req, parts); + req.parts.namespace = Some(Namespace(self.config.namespace.clone())); + req + })), } } - async fn len(&self) -> Result { + async fn len(&mut self) -> Result { let pool = self.pool.clone(); let query = "Select Count(*) as count from jobs where status='Pending'"; @@ -276,12 +307,13 @@ where record.try_get("count") } - async fn reschedule(&mut self, job: Request, wait: Duration) -> Result<(), sqlx::Error> { + async fn reschedule( + &mut self, + job: Request, + wait: Duration, + ) -> Result<(), sqlx::Error> { let pool = self.pool.clone(); - let job_id = job.get::().ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing TaskId", - )))?; + let job_id = job.parts.task_id.clone(); let wait: i64 = wait .as_secs() @@ -299,21 +331,16 @@ where Ok(()) } - async fn update(&self, job: Request) -> Result<(), sqlx::Error> { + async fn update(&mut self, job: Request) -> Result<(), sqlx::Error> { let pool = self.pool.clone(); - let ctx = job - .get::() - .ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing TaskId", - )))?; + let ctx = job.parts.context; let status = 
ctx.status().to_string(); - let attempts = ctx.attempts(); + let attempts = job.parts.attempt; let done_at = *ctx.done_at(); let lock_by = ctx.lock_by().clone(); let lock_at = *ctx.lock_at(); let last_error = ctx.last_error().clone(); - let job_id = ctx.id(); + let job_id = job.parts.task_id; let mut tx = pool.acquire().await?; let query = "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ? WHERE id = ?"; @@ -335,11 +362,11 @@ where Ok(()) } - async fn is_empty(&self) -> Result { + async fn is_empty(&mut self) -> Result { Ok(self.len().await? == 0) } - async fn vacuum(&self) -> Result { + async fn vacuum(&mut self) -> Result { let pool = self.pool.clone(); let query = "Delete from jobs where status='Done'"; let record = sqlx::query(query).execute(&pool).await?; @@ -347,27 +374,49 @@ where } } -impl Backend> - for MysqlStorage -{ - type Stream = BackendStream>>; +/// Errors that can occur while polling a MySQL database. +#[derive(thiserror::Error, Debug)] +pub enum MysqlPollError { + /// Error during task acknowledgment. + #[error("Encountered an error during ACK: `{0}`")] + AckError(sqlx::Error), - type Layer = AckLayer, T>; + /// Error during result encoding. + #[error("Encountered an error during encoding the result: {0}")] + CodecError(BoxDynError), - fn common_layer(&self, worker_id: WorkerId) -> Self::Layer { - AckLayer::new(self.clone(), worker_id) - } + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), - fn poll(self, worker: WorkerId) -> Poller { + /// Error during re-enqueuing orphaned tasks. 
+ #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), +} + +impl Backend, Res> for MysqlStorage +where + Req: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, + C: Debug + Codec + Clone + Send + 'static + Sync, + C::Error: std::error::Error + 'static + Send + Sync, +{ + type Stream = BackendStream>>; + + type Layer = AckLayer, Req, SqlContext, Res>; + + fn poll(self, worker: &Worker) -> Poller { + let layer = AckLayer::new(self.clone()); let config = self.config.clone(); let controller = self.controller.clone(); let pool = self.pool.clone(); let ack_notify = self.ack_notify.clone(); let mut hb_storage = self.clone(); + let requeue_storage = self.clone(); let stream = self - .stream_jobs(&worker, config.poll_interval, config.buffer_size) - .map_err(|e| Error::SourceError(Box::new(e))); + .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) + .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); + let w = worker.clone(); let ack_heartbeat = async move { while let Some(ids) = ack_notify @@ -376,58 +425,101 @@ impl Back .next() .await { - let worker_ids: Vec = ids.iter().map(|c| c.0.to_string()).collect(); - let task_ids: Vec = ids.iter().map(|c| c.1.to_string()).collect(); - let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1)); - let worker_params = format!("?{}", ", ?".repeat(worker_ids.len() - 1)); - let query = - format!("UPDATE jobs SET status = 'Done', done_at = now() WHERE id IN ( { } ) AND lock_by IN ( { } )", id_params, worker_params); - let mut query = sqlx::query(&query); - for i in task_ids { - query = query.bind(i); - } - for i in worker_ids { - query = query.bind(i); - } - if let Err(e) = query.execute(&pool).await { - error!("Ack failed: {e}"); + for (ctx, res) in ids { + let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ? WHERE id = ? 
AND lock_by = ?"; + let query = sqlx::query(query); + let last_result = + C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new); + match (last_result, ctx.lock_by()) { + (Ok(val), Some(worker_id)) => { + let query = query + .bind(calculate_status(&res.inner).to_string()) + .bind(val) + .bind(res.task_id.to_string()) + .bind(worker_id.to_string()); + if let Err(e) = query.execute(&pool).await { + w.emit(Event::Error(Box::new(MysqlPollError::AckError(e)))); + } + } + (Err(error), Some(_)) => { + w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error)))); + } + _ => { + unreachable!( + "Attempted to ACK without a worker attached. This is a bug, File it on the repo" + ); + } + } } + apalis_core::sleep(config.poll_interval).await; } }; - + let w = worker.clone(); let heartbeat = async move { loop { let now = Utc::now(); - if let Err(e) = hb_storage.keep_alive_at::(&worker, now).await { - error!("Heartbeat failed: {e}"); + if let Err(e) = hb_storage.keep_alive_at::(w.id(), now).await { + w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e)))); } apalis_core::sleep(config.keep_alive).await; } }; - Poller::new(stream, async { - futures::join!(heartbeat, ack_heartbeat); - }) + let w = worker.clone(); + let reenqueue_beat = async move { + loop { + let dead_since = Utc::now() + - chrono::Duration::from_std(config.reenqueue_orphaned_after) + .expect("Could not calculate dead since"); + if let Err(e) = requeue_storage + .reenqueue_orphaned( + config + .buffer_size + .try_into() + .expect("Could not convert usize to i32"), + dead_since, + ) + .await + { + w.emit(Event::Error(Box::new( + MysqlPollError::ReenqueueOrphanedError(e), + ))); + } + apalis_core::sleep(config.poll_interval).await; + } + }; + Poller::new_with_layer( + stream, + async { + futures::join!(heartbeat, ack_heartbeat, reenqueue_beat); + }, + layer, + ) } } -impl Ack for MysqlStorage { - type Acknowledger = TaskId; - type Error = sqlx::Error; - async fn ack( - &self, - 
worker_id: &WorkerId, - task_id: &Self::Acknowledger, - ) -> Result<(), sqlx::Error> { +impl Ack for MysqlStorage +where + T: Sync + Send, + Res: Serialize + Send + 'static + Sync, + C: Codec + Send, + C::Error: Debug, +{ + type Context = SqlContext; + type AckError = sqlx::Error; + async fn ack(&mut self, ctx: &Self::Context, res: &Response) -> Result<(), sqlx::Error> { self.ack_notify - .notify((worker_id.clone(), task_id.clone())) + .notify(( + ctx.clone(), + res.map(|res| C::encode(res).expect("Could not encode result")), + )) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?; Ok(()) } } -impl MysqlStorage { +impl MysqlStorage { /// Kill a job pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> { let pool = self.pool.clone(); @@ -463,44 +555,116 @@ impl MysqlStorage { } /// Readd jobs that are abandoned to the queue - pub async fn reenqueue_orphaned(&self, timeout: i64) -> Result { - let job_type = T::NAME; + pub async fn reenqueue_orphaned( + &self, + count: i32, + dead_since: DateTime, + ) -> Result { + let job_type = self.config.namespace.clone(); let mut tx = self.pool.acquire().await?; let query = r#"Update jobs INNER JOIN ( SELECT workers.id as worker_id, jobs.id as job_id from workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.status = "Running" AND workers.last_seen < ? AND workers.worker_type = ? ORDER BY lock_at ASC LIMIT ?) 
as workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ="Job was abandoned";"#; - let now = Utc::now().timestamp(); - let seconds_ago = DateTime::from_timestamp(now - timeout, 0).ok_or(sqlx::Error::Io( - io::Error::new(io::ErrorKind::InvalidData, "Invalid timeout"), - ))?; sqlx::query(query) - .bind(seconds_ago) + .bind(dead_since) .bind(job_type) - .bind::( - self.config - .buffer_size - .try_into() - .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?, - ) + .bind(count) .execute(&mut *tx) .await?; Ok(true) } } +impl BackendExpose + for MysqlStorage +{ + type Request = Request>; + type Error = SqlError; + async fn stats(&self) -> Result { + let fetch_query = "SELECT + COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending, + COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running, + COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done, + COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry, + COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed, + COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed + FROM jobs WHERE job_type = ?"; + + let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query) + .bind(self.get_config().namespace()) + .fetch_one(self.pool()) + .await?; + + Ok(Stat { + pending: res.0.try_into()?, + running: res.1.try_into()?, + dead: res.4.try_into()?, + failed: res.3.try_into()?, + success: res.2.try_into()?, + }) + } + + async fn list_jobs( + &self, + status: &State, + page: i32, + ) -> Result, Self::Error> { + let status = status.to_string(); + let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? 
ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?"; + let res: Vec> = sqlx::query_as(fetch_query) + .bind(status) + .bind(self.get_config().namespace()) + .bind(((page - 1) * 10).to_string()) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|j| { + let (req, ctx) = j.req.take_parts(); + let req: J = MysqlCodec::decode(req).unwrap(); + Request::new_with_ctx(req, ctx) + }) + .collect()) + } + + async fn list_workers(&self) -> Result>, Self::Error> { + let fetch_query = + "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?"; + let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query) + .bind(self.get_config().namespace()) + .bind(0) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::(w.1))) + .collect()) + } +} + #[cfg(test)] mod tests { - use crate::context::State; + use crate::sql_storage_tests; use super::*; + + use apalis_core::test_utils::DummyService; use email_service::Email; use futures::StreamExt; + use apalis_core::generic_storage_test; + use apalis_core::test_utils::apalis_test_service_fn; + use apalis_core::test_utils::TestWrapper; + + generic_storage_test!(setup); + + sql_storage_tests!(setup::, MysqlStorage, Email); + /// migrate DB and return a storage instance. 
- async fn setup() -> MysqlStorage { + async fn setup() -> MysqlStorage { let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified"); // Because connections cannot be shared across async runtime // (different runtimes are created for each test), @@ -509,8 +673,8 @@ mod tests { MysqlStorage::setup(&pool) .await .expect("failed to migrate DB"); - let storage = MysqlStorage::new(pool); - + let mut storage = MysqlStorage::new(pool); + cleanup(&mut storage, &WorkerId::new("test-worker")).await; storage } @@ -520,9 +684,9 @@ mod tests { /// - worker identified by `worker_id` /// /// You should execute this function in the end of a test - async fn cleanup(storage: MysqlStorage, worker_id: &WorkerId) { - sqlx::query("DELETE FROM jobs WHERE lock_by = ? OR status = 'Pending'") - .bind(worker_id.to_string()) + async fn cleanup(storage: &mut MysqlStorage, worker_id: &WorkerId) { + sqlx::query("DELETE FROM jobs WHERE job_type = ?") + .bind(storage.config.namespace()) .execute(&storage.pool) .await .expect("failed to delete jobs"); @@ -533,9 +697,14 @@ mod tests { .expect("failed to delete worker"); } - async fn consume_one(storage: &MysqlStorage, worker_id: &WorkerId) -> Request { - let storage = storage.clone(); - let mut stream = storage.stream_jobs(worker_id, std::time::Duration::from_secs(10), 1); + async fn consume_one( + storage: &mut MysqlStorage, + worker_id: &WorkerId, + ) -> Request { + let mut stream = + storage + .clone() + .stream_jobs(worker_id, std::time::Duration::from_secs(10), 1); stream .next() .await @@ -552,8 +721,6 @@ mod tests { } } - struct DummyService {} - async fn register_worker_at( storage: &mut MysqlStorage, last_seen: DateTime, @@ -573,17 +740,16 @@ mod tests { register_worker_at(storage, now).await } - async fn push_email(storage: &mut S, email: Email) - where - S: Storage, - { + async fn push_email(storage: &mut MysqlStorage, email: Email) { storage.push(email).await.expect("failed to push a job"); } - async fn 
get_job(storage: &mut S, job_id: &TaskId) -> Request - where - S: Storage, - { + async fn get_job( + storage: &mut MysqlStorage, + job_id: &TaskId, + ) -> Request { + // add a slight delay to allow background actions like ack to complete + apalis_core::sleep(Duration::from_secs(1)).await; storage .fetch_by_id(job_id) .await @@ -599,39 +765,11 @@ mod tests { let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - // TODO: Fix assertions - // assert_eq!(*ctx.status(), State::Running); - // assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); - // assert!(ctx.lock_at().is_some()); - - cleanup(storage, &worker_id).await; - } - - #[tokio::test] - async fn test_acknowledge_job() { - let mut storage = setup().await; - push_email(&mut storage, example_email()).await; - - let worker_id = register_worker(&mut storage).await; - - let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); - - storage - .ack(&worker_id, job_id) - .await - .expect("failed to acknowledge the job"); - - let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); - + let ctx = job.parts.context; // TODO: Fix assertions - // assert_eq!(*ctx.status(), State::Done); - // assert!(ctx.done_at().is_some()); - - cleanup(storage, &worker_id).await; + assert_eq!(*ctx.status(), State::Running); + assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); + assert!(ctx.lock_at().is_some()); } #[tokio::test] @@ -644,8 +782,7 @@ mod tests { let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); + let job_id = &job.parts.task_id; storage .kill(&worker_id, job_id) @@ -653,12 +790,10 @@ mod tests { .expect("failed to kill job"); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; // TODO: Fix assertions - // assert_eq!(*ctx.status(), 
State::Killed); - // assert!(ctx.done_at().is_some()); - - cleanup(storage, &worker_id).await; + assert_eq!(*ctx.status(), State::Killed); + assert!(ctx.done_at().is_some()); } #[tokio::test] @@ -674,6 +809,8 @@ mod tests { // register a worker not responding since 6 minutes ago let worker_id = WorkerId::new("test-worker"); + let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); + let six_minutes_ago = Utc::now() - Duration::from_secs(60 * 6); storage @@ -683,23 +820,28 @@ mod tests { // fetch job let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - storage.reenqueue_orphaned(300).await.unwrap(); + storage + .reenqueue_orphaned(1, five_minutes_ago) + .await + .unwrap(); // then, the job status has changed to Pending - let job = storage.fetch_by_id(ctx.id()).await.unwrap().unwrap(); - let context = job.get::().unwrap(); - // TODO: Fix assertions - // assert_eq!(*context.status(), State::Pending); - // assert!(context.lock_by().is_none()); - // assert!(context.lock_at().is_none()); - // assert!(context.done_at().is_none()); - // assert_eq!(*context.last_error(), Some("Job was abandoned".to_string())); - - cleanup(storage, &worker_id).await; + let job = storage + .fetch_by_id(&job.parts.task_id) + .await + .unwrap() + .unwrap(); + let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Pending); + assert!(ctx.done_at().is_none()); + assert!(ctx.lock_by().is_none()); + assert!(ctx.lock_at().is_none()); + assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned())); + assert_eq!(job.parts.attempt.current(), 1); } #[tokio::test] @@ -714,6 +856,7 @@ mod tests { // register a worker responding at 4 minutes ago let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60); + let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let worker_id = WorkerId::new("test-worker"); storage @@ -723,20 +866,27 @@ mod tests { // fetch job 
let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); // heartbeat with ReenqueueOrpharned pulse - storage.reenqueue_orphaned(300).await.unwrap(); + storage + .reenqueue_orphaned(1, six_minutes_ago) + .await + .unwrap(); // then, the job status is not changed - let job = storage.fetch_by_id(ctx.id()).await.unwrap().unwrap(); - let context = job.get::().unwrap(); - // TODO: Fix assertions - // assert_eq!(*context.status(), State::Running); - // assert_eq!(*context.lock_by(), Some(worker_id.clone())); - - cleanup(storage, &worker_id).await; + let job = storage + .fetch_by_id(&job.parts.task_id) + .await + .unwrap() + .unwrap(); + let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Running); + assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert!(ctx.lock_at().is_some()); + assert_eq!(*ctx.last_error(), None); + assert_eq!(job.parts.attempt.current(), 1); } } diff --git a/packages/apalis-sql/src/postgres.rs b/packages/apalis-sql/src/postgres.rs index 8d68865..26664f2 100644 --- a/packages/apalis-sql/src/postgres.rs +++ b/packages/apalis-sql/src/postgres.rs @@ -27,11 +27,11 @@ //! // let query = "Select apalis.push_job('apalis::Email', json_build_object('subject', 'Test apalis', 'to', 'test1@example.com', 'text', 'Lorem Ipsum'));"; //! // pg.execute(query).await.unwrap(); //! -//! Monitor::::new() -//! .register_with_count(4, { +//! Monitor::new() +//! .register({ //! WorkerBuilder::new(&format!("tasty-avocado")) //! .data(0usize) -//! .source(pg) +//! .backend(pg) //! .build_fn(send_email) //! }) //! .run() @@ -39,28 +39,34 @@ //! } //! 
``` use crate::context::SqlContext; -use crate::Config; +use crate::{calculate_status, Config, SqlError}; +use apalis_core::backend::{BackendExpose, Stat, WorkerState}; use apalis_core::codec::json::JsonCodec; -use apalis_core::error::Error; +use apalis_core::error::{BoxDynError, Error}; use apalis_core::layers::{Ack, AckLayer}; use apalis_core::notify::Notify; use apalis_core::poller::controller::Controller; use apalis_core::poller::stream::BackendStream; use apalis_core::poller::Poller; -use apalis_core::request::{Request, RequestStream}; -use apalis_core::storage::{Job, Storage}; +use apalis_core::request::{Parts, Request, RequestStream, State}; +use apalis_core::response::Response; +use apalis_core::storage::Storage; +use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; -use async_stream::try_stream; -use futures::{FutureExt, Stream}; -use futures::{StreamExt, TryStreamExt}; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; +use apalis_core::{backend::Backend, codec::Codec}; +use chrono::{DateTime, Utc}; +use futures::channel::mpsc; +use futures::StreamExt; +use futures::{select, stream, SinkExt}; use log::error; use serde::{de::DeserializeOwned, Serialize}; +use serde_json::Value; use sqlx::postgres::PgListener; -use sqlx::types::chrono::{DateTime, Utc}; use sqlx::{Pool, Postgres, Row}; +use std::any::type_name; use std::convert::TryInto; +use std::fmt::Debug; use std::sync::Arc; use std::{fmt, io}; use std::{marker::PhantomData, time::Duration}; @@ -73,107 +79,204 @@ use crate::from_row::SqlRequest; /// Represents a [Storage] that persists to Postgres // #[derive(Debug)] -pub struct PostgresStorage { +pub struct PostgresStorage> +where + C: Codec, +{ pool: PgPool, job_type: PhantomData, - codec: Arc< - Box< - dyn Codec - + Sync - + Send - + 'static, - >, - >, + codec: PhantomData, config: Config, controller: Controller, - ack_notify: 
Notify<(WorkerId, TaskId)>, + ack_notify: Notify<(SqlContext, Response)>, + subscription: Option, } -impl Clone for PostgresStorage { +impl Clone for PostgresStorage { fn clone(&self) -> Self { PostgresStorage { pool: self.pool.clone(), job_type: PhantomData, - codec: self.codec.clone(), + codec: PhantomData, config: self.config.clone(), controller: self.controller.clone(), ack_notify: self.ack_notify.clone(), + subscription: self.subscription.clone(), } } } -impl fmt::Debug for PostgresStorage { +impl fmt::Debug for PostgresStorage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PostgresStorage") .field("pool", &self.pool) .field("job_type", &"PhantomData") .field("controller", &self.controller) .field("config", &self.config) - .field( - "codec", - &"Arc + Sync + Send + 'static>>", - ) - .field("ack_notify", &self.ack_notify) + .field("codec", &std::any::type_name::()) + // .field("ack_notify", &std::any::type_name_of_val(&self.ack_notify)) .finish() } } -impl Backend> - for PostgresStorage -{ - type Stream = BackendStream>>; +/// Errors that can occur while polling a PostgreSQL database. +#[derive(thiserror::Error, Debug)] +pub enum PgPollError { + /// Error during task acknowledgment. + #[error("Encountered an error during ACK: `{0}`")] + AckError(sqlx::Error), - type Layer = AckLayer, T>; + /// Error while fetching the next item. + #[error("Encountered an error during FetchNext: `{0}`")] + FetchNextError(apalis_core::error::Error), - fn common_layer(&self, worker_id: WorkerId) -> Self::Layer { - AckLayer::new(self.clone(), worker_id) - } + /// Error while listening to PostgreSQL notifications. + #[error("Encountered an error during listening to PgNotification: {0}")] + PgNotificationError(apalis_core::error::Error), + + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), + + /// Error during re-enqueuing orphaned tasks. 
+ #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), + + /// Error during result encoding. + #[error("Encountered an error during encoding the result: {0}")] + CodecError(BoxDynError), +} + +impl Backend, Res> for PostgresStorage +where + T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, + C: Codec + Send + 'static, + C::Error: std::error::Error + 'static + Send + Sync, +{ + type Stream = BackendStream>>; + + type Layer = AckLayer, T, SqlContext, Res>; - fn poll(mut self, worker: WorkerId) -> Poller { + fn poll(mut self, worker: &Worker) -> Poller { + let layer = AckLayer::new(self.clone()); + let subscription = self.subscription.clone(); let config = self.config.clone(); let controller = self.controller.clone(); - let stream = self - .stream_jobs(&worker, config.poll_interval, config.buffer_size) - .map_err(|e| Error::SourceError(Box::new(e))); - let stream = BackendStream::new(stream.boxed(), controller); + let (mut tx, rx) = mpsc::channel(self.config.buffer_size); let ack_notify = self.ack_notify.clone(); let pool = self.pool.clone(); - let ack_heartbeat = async move { - while let Some(ids) = ack_notify - .clone() - .ready_chunks(config.buffer_size) - .next() - .await - { - let worker_ids: Vec = ids.iter().map(|c| c.0.to_string()).collect(); - let task_ids: Vec = ids.iter().map(|c| c.1.to_string()).collect(); - - let query = - "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = ANY($1::text[]) AND lock_by = ANY($2::text[])"; - if let Err(e) = sqlx::query(query) - .bind(task_ids) - .bind(worker_ids) - .execute(&pool) + let worker = worker.clone(); + let heartbeat = async move { + let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse(); + let mut reenqueue_orphaned_stm = + apalis_core::interval::interval(config.poll_interval).fuse(); + let mut ack_stream = ack_notify.clone().ready_chunks(config.buffer_size).fuse(); + + let mut poll_next_stm = 
apalis_core::interval::interval(config.poll_interval).fuse(); + + let mut pg_notification = subscription + .map(|stm| stm.notify.boxed().fuse()) + .unwrap_or(stream::iter(vec![]).boxed().fuse()); + + async fn fetch_next_batch< + T: Unpin + DeserializeOwned + Send + 'static, + C: Codec, + >( + storage: &mut PostgresStorage, + worker: &WorkerId, + tx: &mut mpsc::Sender>, Error>>, + ) -> Result<(), Error> { + let res = storage + .fetch_next(worker) .await - { - error!("Ack failed: {e}"); + .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?; + for job in res { + tx.send(Ok(Some(job))) + .await + .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?; } - apalis_core::sleep(config.poll_interval).await; + Ok(()) } - }; - let heartbeat = async move { + + if let Err(e) = self + .keep_alive_at::(worker.id(), Utc::now().timestamp()) + .await + { + worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e)))); + } + loop { - let now: i64 = Utc::now().timestamp(); - if let Err(e) = self.keep_alive_at::(&worker, now).await { - error!("Heartbeat failed: {e}") - } - apalis_core::sleep(config.keep_alive).await; + select! 
{ + _ = keep_alive_stm.next() => { + if let Err(e) = self.keep_alive_at::(worker.id(), Utc::now().timestamp()).await { + worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e)))); + } + } + ids = ack_stream.next() => { + if let Some(ids) = ids { + let ack_ids: Vec<(String, String, String, String, u64)> = ids.iter().map(|(_ctx, res)| { + (res.task_id.to_string(), worker.id().to_string(), serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())).expect("Could not convert response to json"), calculate_status(&res.inner).to_string(), (res.attempt.current() + 1) as u64 ) + }).collect(); + let query = + "UPDATE apalis.jobs + SET status = Q.status, + done_at = now(), + lock_by = Q.worker_id, + last_error = Q.result, + attempts = Q.attempts + FROM ( + SELECT (value->>0)::text as id, + (value->>1)::text as worker_id, + (value->>2)::text as result, + (value->>3)::text as status, + (value->>4)::int as attempts + FROM json_array_elements($1::json) + ) Q + WHERE apalis.jobs.id = Q.id; + "; + let codec_res = C::encode(&ack_ids); + match codec_res { + Ok(val) => { + if let Err(e) = sqlx::query(query) + .bind(val) + .execute(&pool) + .await + { + worker.emit(Event::Error(Box::new(PgPollError::AckError(e)))); + } + } + Err(e) => { + worker.emit(Event::Error(Box::new(PgPollError::CodecError(e.into())))); + } + } + + } + } + _ = poll_next_stm.next() => { + if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { + worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e)))); + + } + } + _ = pg_notification.next() => { + if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { + worker.emit(Event::Error(Box::new(PgPollError::PgNotificationError(e)))); + + } + } + _ = reenqueue_orphaned_stm.next() => { + let dead_since = Utc::now() + - chrono::Duration::from_std(config.reenqueue_orphaned_after).expect("could not build dead_since"); + if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await 
{ + worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(e)))); + } + } + + + }; } - } - .boxed(); - Poller::new(stream, async { - futures::join!(heartbeat, ack_heartbeat); - }) + }; + Poller::new_with_layer(BackendStream::new(rx.boxed(), controller), heartbeat, layer) } } @@ -192,20 +295,21 @@ impl PostgresStorage<()> { } } -impl PostgresStorage { +impl PostgresStorage { /// New Storage from [PgPool] pub fn new(pool: PgPool) -> Self { - Self::new_with_config(pool, Config::default()) + Self::new_with_config(pool, Config::new(type_name::())) } /// New Storage from [PgPool] and custom config pub fn new_with_config(pool: PgPool, config: Config) -> Self { Self { pool, job_type: PhantomData, - codec: Arc::new(Box::new(JsonCodec)), + codec: PhantomData, config, controller: Controller::new(), ack_notify: Notify::new(), + subscription: None, } } @@ -213,6 +317,43 @@ impl PostgresStorage { pub fn pool(&self) -> &Pool { &self.pool } + + /// Expose the config + pub fn config(&self) -> &Config { + &self.config + } +} + +impl PostgresStorage { + /// Expose the codec + pub fn codec(&self) -> &PhantomData { + &self.codec + } + + async fn keep_alive_at( + &mut self, + worker_id: &WorkerId, + last_seen: Timestamp, + ) -> Result<(), sqlx::Error> { + let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io( + io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"), + ))?; + let worker_type = self.config.namespace.clone(); + let storage_name = std::any::type_name::(); + let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (id) DO + UPDATE SET last_seen = EXCLUDED.last_seen"; + sqlx::query(query) + .bind(worker_id.to_string()) + .bind(worker_type) + .bind(storage_name) + .bind(std::any::type_name::()) + .bind(last_seen) + .execute(&self.pool) + .await?; + Ok(()) + } } /// A listener that listens to Postgres notifications @@ -240,12 +381,22 @@ impl PgListen { 
}) } + /// Add a new subscription with a storage + pub fn subscribe_with(&mut self, storage: &mut PostgresStorage) { + let sub = PgSubscription { + notify: Notify::new(), + }; + self.subscriptions + .push((storage.config.namespace.to_owned(), sub.clone())); + storage.subscription = Some(sub) + } + /// Add a new subscription - pub fn subscribe(&mut self) -> PgSubscription { + pub fn subscribe(&mut self, namespace: &str) -> PgSubscription { let sub = PgSubscription { notify: Notify::new(), }; - self.subscriptions.push((T::NAME.to_owned(), sub.clone())); + self.subscriptions.push((namespace.to_owned(), sub.clone())); sub } /// Start listening to jobs @@ -264,79 +415,54 @@ impl PgListen { } } -impl PostgresStorage { - fn stream_jobs( - &self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> impl Stream>, sqlx::Error>> { - let pool = self.pool.clone(); - let worker_id = worker_id.clone(); - let codec = self.codec.clone(); - try_stream! { - loop { - // Ideally wait for a job or a tick - apalis_core::sleep(interval).await; - let tx = pool.clone(); - let job_type = T::NAME; - let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);"; - let jobs: Vec> = sqlx::query_as(fetch_query) - .bind(worker_id.to_string()) - .bind(job_type) - // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html - .bind(i32::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?) 
- .fetch_all(&tx) - .await?; - for job in jobs { - - yield Some(Into::into(SqlRequest { - context: job.context, - req: codec.decode(&job.req).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?, - })) - } - - } - } - .boxed() - } - - async fn keep_alive_at( +impl PostgresStorage +where + T: DeserializeOwned + Send + Unpin + 'static, + C: Codec, +{ + async fn fetch_next( &mut self, worker_id: &WorkerId, - last_seen: Timestamp, - ) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io( - io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"), - ))?; - let worker_type = T::NAME; - let storage_name = std::any::type_name::(); - let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (id) DO - UPDATE SET last_seen = EXCLUDED.last_seen"; - sqlx::query(query) + ) -> Result>, sqlx::Error> { + let config = &self.config; + let job_type = &config.namespace; + let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);"; + let jobs: Vec> = sqlx::query_as(fetch_query) .bind(worker_id.to_string()) - .bind(worker_type) - .bind(storage_name) - .bind(std::any::type_name::()) - .bind(last_seen) - .execute(&pool) + .bind(job_type) + // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html + .bind( + i32::try_from(config.buffer_size) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?, + ) + .fetch_all(&self.pool) .await?; - Ok(()) + let jobs: Vec<_> = jobs + .into_iter() + .map(|job| { + let (req, parts) = job.req.take_parts(); + let req = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) + .expect("Unable to decode"); + let mut req = Request::new_with_parts(req, parts); + req.parts.namespace = Some(Namespace(self.config.namespace.clone())); + req + }) + .collect(); + Ok(jobs) } } -impl Storage for 
PostgresStorage +impl Storage for PostgresStorage where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + Req: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + C: Codec + Send + 'static, { - type Job = T; + type Job = Req; type Error = sqlx::Error; - type Identifier = TaskId; + type Context = SqlContext; /// Push a job to Postgres [Storage] /// @@ -345,88 +471,87 @@ where /// ```sql /// Select apalis.push_job(job_type::text, job::json); /// ``` - async fn push(&mut self, job: Self::Job) -> Result { - let id = TaskId::new(); - let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let job = self - .codec - .encode(&job) + async fn push_request( + &mut self, + req: Request, + ) -> Result, sqlx::Error> { + let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW() , NULL, NULL, NULL, NULL)"; + + let args = C::encode(&req.args) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type.to_string()) - .execute(&pool) + .bind(args) + .bind(req.parts.task_id.to_string()) + .bind(&job_type) + .bind(&req.parts.context.max_attempts()) + .execute(&self.pool) .await?; - Ok(id) + Ok(req.parts) } - async fn schedule(&mut self, job: Self::Job, on: Timestamp) -> Result { + async fn schedule_request( + &mut self, + req: Request, + on: Timestamp, + ) -> Result, sqlx::Error> { let query = - "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let id = TaskId::new(); + "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, $5, NULL, NULL, NULL, NULL)"; + let task_id = req.parts.task_id.to_string(); + let parts = req.parts; let on = DateTime::from_timestamp(on, 0); - let job = self - .codec - 
.encode(&job) + let job = C::encode(&req.args) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) .bind(job) - .bind(id.to_string()) + .bind(task_id) .bind(job_type) + .bind(&parts.context.max_attempts()) .bind(on) - .execute(&pool) + .execute(&self.pool) .await?; - Ok(id) + Ok(parts) } async fn fetch_by_id( - &self, + &mut self, job_id: &TaskId, - ) -> Result>, sqlx::Error> { - let pool = self.pool.clone(); - - let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1"; + ) -> Result>, sqlx::Error> { + let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1 LIMIT 1"; let res: Option> = sqlx::query_as(fetch_query) .bind(job_id.to_string()) - .fetch_optional(&pool) + .fetch_optional(&self.pool) .await?; + match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, parts) = job.req.take_parts(); + let args = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + + let mut req: Request = Request::new_with_parts(args, parts); + req.parts.namespace = Some(Namespace(self.config.namespace.clone())); + req + })), } } - async fn len(&self) -> Result { - let pool = self.pool.clone(); + async fn len(&mut self) -> Result { let query = "Select Count(*) as count from apalis.jobs where status='Pending'"; - let record = sqlx::query(query).fetch_one(&pool).await?; + let record = sqlx::query(query).fetch_one(&self.pool).await?; record.try_get("count") } - async fn reschedule(&mut self, job: Request, wait: Duration) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - let ctx = job - .get::() - .ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing SqlContext", - )))?; - let job_id 
= ctx.id(); + async fn reschedule( + &mut self, + job: Request, + wait: Duration, + ) -> Result<(), sqlx::Error> { + let job_id = job.parts.task_id; let on = Utc::now() + wait; - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1"; @@ -438,18 +563,13 @@ where Ok(()) } - async fn update(&self, job: Request) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - let ctx = job - .get::() - .ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing SqlContext", - )))?; - let job_id = ctx.id(); + async fn update(&mut self, job: Request) -> Result<(), sqlx::Error> { + let ctx = job.parts.context; + let job_id = job.parts.task_id; let status = ctx.status().to_string(); - let attempts: i32 = ctx - .attempts() + let attempts: i32 = job + .parts + .attempt .current() .try_into() .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; @@ -458,7 +578,7 @@ where let lock_at = *ctx.lock_at(); let last_error = ctx.last_error().clone(); - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7"; sqlx::query(query) @@ -474,44 +594,48 @@ where Ok(()) } - async fn is_empty(&self) -> Result { + async fn is_empty(&mut self) -> Result { Ok(self.len().await? 
== 0) } - async fn vacuum(&self) -> Result { - let pool = self.pool.clone(); + async fn vacuum(&mut self) -> Result { let query = "Delete from apalis.jobs where status='Done'"; - let record = sqlx::query(query).execute(&pool).await?; + let record = sqlx::query(query).execute(&self.pool).await?; Ok(record.rows_affected().try_into().unwrap_or_default()) } } -impl Ack for PostgresStorage { - type Acknowledger = TaskId; - type Error = sqlx::Error; - async fn ack( - &self, - worker_id: &WorkerId, - task_id: &Self::Acknowledger, - ) -> Result<(), sqlx::Error> { +impl Ack for PostgresStorage +where + T: Sync + Send, + Res: Serialize + Sync + Clone, + C: Codec + Send, +{ + type Context = SqlContext; + type AckError = sqlx::Error; + async fn ack(&mut self, ctx: &Self::Context, res: &Response) -> Result<(), sqlx::Error> { + let res = res.clone().map(|r| { + C::encode(r) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e))) + .expect("Could not encode result") + }); + self.ack_notify - .notify((worker_id.clone(), task_id.clone())) + .notify((ctx.clone(), res)) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))?; Ok(()) } } -impl PostgresStorage { +impl PostgresStorage { /// Kill a job pub async fn kill( &mut self, worker_id: &WorkerId, task_id: &TaskId, ) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; sqlx::query(query) @@ -523,15 +647,13 @@ impl PostgresStorage { } /// Puts the job instantly back into the queue - /// Another [Worker] may consume + /// Another Worker may consume pub async fn retry( &mut self, worker_id: &WorkerId, task_id: &TaskId, ) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE apalis.jobs SET 
status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; sqlx::query(query) @@ -543,60 +665,146 @@ impl PostgresStorage { } /// Reenqueue jobs that have been abandoned by their workers - pub async fn reenqueue_orphaned(&self, count: i32) -> Result<(), sqlx::Error> - where - T: Job, - { - let job_type = T::NAME; + pub async fn reenqueue_orphaned( + &mut self, + count: i32, + dead_since: DateTime, + ) -> Result<(), sqlx::Error> { + let job_type = self.config.namespace.clone(); let mut tx = self.pool.acquire().await?; - let query = "Update apalis.jobs - SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned' - WHERE id in - (SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id - WHERE status= 'Running' AND workers.last_seen < (NOW() - INTERVAL '300 seconds') - AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);"; + let query = "UPDATE apalis.jobs + SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned' + WHERE id IN + (SELECT jobs.id FROM apalis.jobs INNER JOIN apalis.workers ON lock_by = workers.id + WHERE status = 'Running' + AND workers.last_seen < ($3::timestamp) + AND workers.worker_type = $1 + ORDER BY lock_at ASC + LIMIT $2);"; + sqlx::query(query) .bind(job_type) .bind(count) + .bind(dead_since) .execute(&mut *tx) .await?; Ok(()) } } + +impl BackendExpose + for PostgresStorage +{ + type Request = Request>; + type Error = SqlError; + async fn stats(&self) -> Result { + let fetch_query = "SELECT + COUNT(1) FILTER (WHERE status = 'Pending') AS pending, + COUNT(1) FILTER (WHERE status = 'Running') AS running, + COUNT(1) FILTER (WHERE status = 'Done') AS done, + COUNT(1) FILTER (WHERE status = 'Retry') AS retry, + COUNT(1) FILTER (WHERE status = 'Failed') AS failed, + COUNT(1) FILTER (WHERE status = 'Killed') AS killed + FROM apalis.jobs WHERE job_type = $1"; + + let res: (i64, i64, i64, i64, i64, i64) = 
sqlx::query_as(fetch_query) + .bind(self.config().namespace()) + .fetch_one(self.pool()) + .await?; + + Ok(Stat { + pending: res.0.try_into()?, + running: res.1.try_into()?, + dead: res.4.try_into()?, + failed: res.3.try_into()?, + success: res.2.try_into()?, + }) + } + + async fn list_jobs( + &self, + status: &State, + page: i32, + ) -> Result, Self::Error> { + let status = status.to_string(); + let fetch_query = "SELECT * FROM apalis.jobs WHERE status = $1 AND job_type = $2 ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET $3"; + let res: Vec> = sqlx::query_as(fetch_query) + .bind(status) + .bind(self.config().namespace()) + .bind(((page - 1) * 10).to_string()) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|j| { + let (req, ctx) = j.req.take_parts(); + let req = JsonCodec::::decode(req).unwrap(); + Request::new_with_ctx(req, ctx) + }) + .collect()) + } + + async fn list_workers(&self) -> Result>, Self::Error> { + let fetch_query = + "SELECT id, layers, last_seen FROM apalis.workers WHERE worker_type = $1 ORDER BY last_seen DESC LIMIT 20 OFFSET $2"; + let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query) + .bind(self.config().namespace()) + .bind(0) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::(w.1))) + .collect()) + } +} + #[cfg(test)] mod tests { - use crate::context::State; + + use crate::sql_storage_tests; use super::*; + use apalis_core::test_utils::DummyService; + use chrono::Utc; use email_service::Email; - use futures::StreamExt; - use sqlx::types::chrono::Utc; + + use apalis_core::generic_storage_test; + use apalis_core::test_utils::apalis_test_service_fn; + use apalis_core::test_utils::TestWrapper; + + generic_storage_test!(setup); + + sql_storage_tests!(setup::, PostgresStorage, Email); /// migrate DB and return a storage instance. 
- async fn setup() -> PostgresStorage { + async fn setup() -> PostgresStorage { let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified"); let pool = PgPool::connect(&db_url).await.unwrap(); // Because connections cannot be shared across async runtime // (different runtimes are created for each test), // we don't share the storage and tests must be run sequentially. PostgresStorage::setup(&pool).await.unwrap(); - let storage = PostgresStorage::new(pool); + let config = Config::new("apalis-ci-tests").set_buffer_size(1); + let mut storage = PostgresStorage::new_with_config(pool, config); + cleanup(&mut storage, &WorkerId::new("test-worker")).await; storage } /// rollback DB changes made by tests. /// Delete the following rows: - /// - jobs whose state is `Pending` or locked by `worker_id` + /// - jobs of the current type /// - worker identified by `worker_id` /// /// You should execute this function in the end of a test - async fn cleanup(storage: PostgresStorage, worker_id: &WorkerId) { + async fn cleanup(storage: &mut PostgresStorage, worker_id: &WorkerId) { let mut tx = storage .pool .acquire() .await .expect("failed to get connection"); - sqlx::query("Delete from apalis.jobs where lock_by = $1 or status = 'Pending'") + sqlx::query("Delete from apalis.jobs where job_type = $1 OR lock_by = $2") + .bind(storage.config.namespace()) .bind(worker_id.to_string()) .execute(&mut *tx) .await @@ -608,8 +816,6 @@ mod tests { .expect("failed to delete worker"); } - struct DummyService {} - fn example_email() -> Email { Email { subject: "Test Subject".to_string(), @@ -621,14 +827,9 @@ mod tests { async fn consume_one( storage: &mut PostgresStorage, worker_id: &WorkerId, - ) -> Request { - let mut stream = storage.stream_jobs(worker_id, std::time::Duration::from_secs(10), 1); - stream - .next() - .await - .expect("stream is empty") - .expect("failed to poll job") - .expect("no job is pending") + ) -> Request { + let req = 
storage.fetch_next(worker_id).await; + req.unwrap()[0].clone() } async fn register_worker_at( @@ -652,7 +853,12 @@ mod tests { storage.push(email).await.expect("failed to push a job"); } - async fn get_job(storage: &mut PostgresStorage, job_id: &TaskId) -> Request { + async fn get_job( + storage: &mut PostgresStorage, + job_id: &TaskId, + ) -> Request { + // add a slight delay to allow background actions like ack to complete + apalis_core::sleep(Duration::from_secs(2)).await; storage .fetch_by_id(job_id) .await @@ -668,37 +874,14 @@ mod tests { let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); - // TODO: assert!(ctx.lock_at().is_some()); - - cleanup(storage, &worker_id).await; - } - - #[tokio::test] - async fn test_acknowledge_job() { - let mut storage = setup().await; - push_email(&mut storage, example_email()).await; - - let worker_id = register_worker(&mut storage).await; - - let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); - - storage - .ack(&worker_id, job_id) - .await - .expect("failed to acknowledge the job"); + let job_id = &job.parts.task_id; + // Refresh our job let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); - // TODO: Currently ack is done in the background - // assert_eq!(*ctx.status(), State::Done); - // assert!(ctx.done_at().is_some()); - - cleanup(storage, &worker_id).await; + let ctx = job.parts.context; + assert_eq!(*ctx.status(), State::Running); + assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); + assert!(ctx.lock_at().is_some()); } #[tokio::test] @@ -710,8 +893,7 @@ mod tests { let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); + let job_id = 
&job.parts.task_id; storage .kill(&worker_id, job_id) @@ -719,11 +901,9 @@ mod tests { .expect("failed to kill job"); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Killed); - // TODO: assert!(ctx.done_at().is_some()); - - cleanup(storage, &worker_id).await; + assert!(ctx.done_at().is_some()); } #[tokio::test] @@ -732,25 +912,25 @@ mod tests { push_email(&mut storage, example_email()).await; let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); + let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); + let worker_id = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; let job = consume_one(&mut storage, &worker_id).await; storage - .reenqueue_orphaned(5) + .reenqueue_orphaned(1, five_minutes_ago) .await .expect("failed to heartbeat"); - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); + let job_id = &job.parts.task_id; let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Pending); assert!(ctx.done_at().is_none()); assert!(ctx.lock_by().is_none()); assert!(ctx.lock_at().is_none()); - assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_string())); - - cleanup(storage, &worker_id).await; + assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned())); + assert_eq!(job.parts.attempt.current(), 0); // TODO: update get_jobs to increase attempts } #[tokio::test] @@ -760,25 +940,26 @@ mod tests { push_email(&mut storage, example_email()).await; let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60); + let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let worker_id = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); storage - 
.reenqueue_orphaned(5) + .reenqueue_orphaned(1, six_minutes_ago) .await .expect("failed to heartbeat"); - let job_id = ctx.id(); + let job_id = &job.parts.task_id; let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); - + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); - - cleanup(storage, &worker_id).await; + assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert!(ctx.lock_at().is_some()); + assert_eq!(*ctx.last_error(), None); + assert_eq!(job.parts.attempt.current(), 0); } } diff --git a/packages/apalis-sql/src/sqlite.rs b/packages/apalis-sql/src/sqlite.rs index dfdee04..f562cc4 100644 --- a/packages/apalis-sql/src/sqlite.rs +++ b/packages/apalis-sql/src/sqlite.rs @@ -1,22 +1,26 @@ use crate::context::SqlContext; -use crate::Config; - +use crate::{calculate_status, Config, SqlError}; +use apalis_core::backend::{BackendExpose, Stat, WorkerState}; use apalis_core::codec::json::JsonCodec; use apalis_core::error::Error; use apalis_core::layers::{Ack, AckLayer}; use apalis_core::poller::controller::Controller; use apalis_core::poller::stream::BackendStream; use apalis_core::poller::Poller; -use apalis_core::request::{Request, RequestStream}; -use apalis_core::storage::{Job, Storage}; +use apalis_core::request::{Parts, Request, RequestStream, State}; +use apalis_core::response::Response; +use apalis_core::storage::Storage; +use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; -use apalis_core::{Backend, Codec}; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; +use apalis_core::{backend::Backend, codec::Codec}; use async_stream::try_stream; +use chrono::{DateTime, Utc}; use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use log::error; use serde::{de::DeserializeOwned, Serialize}; -use sqlx::types::chrono::Utc; use sqlx::{Pool, Row, Sqlite}; +use 
std::any::type_name; use std::convert::TryInto; use std::sync::Arc; use std::{fmt, io}; @@ -25,41 +29,37 @@ use std::{marker::PhantomData, time::Duration}; use crate::from_row::SqlRequest; pub use sqlx::sqlite::SqlitePool; + /// Represents a [Storage] that persists to Sqlite // #[derive(Debug)] -pub struct SqliteStorage { +pub struct SqliteStorage> { pool: Pool, job_type: PhantomData, controller: Controller, config: Config, - codec: Arc + Sync + Send + 'static>>, + codec: PhantomData, } -impl fmt::Debug for SqliteStorage { +impl fmt::Debug for SqliteStorage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MysqlStorage") .field("pool", &self.pool) .field("job_type", &"PhantomData") .field("controller", &self.controller) .field("config", &self.config) - .field( - "codec", - &"Arc + Sync + Send + 'static>>", - ) - // .field("ack_notify", &self.ack_notify) + .field("codec", &std::any::type_name::()) .finish() } } impl Clone for SqliteStorage { fn clone(&self) -> Self { - let pool = self.pool.clone(); SqliteStorage { - pool, + pool: self.pool.clone(), job_type: PhantomData, controller: self.controller.clone(), config: self.config.clone(), - codec: self.codec.clone(), + codec: self.codec, } } } @@ -89,10 +89,16 @@ impl SqliteStorage<()> { } } -impl SqliteStorage { - /// Construct a new Storage from a pool +impl SqliteStorage { + /// Create a new instance pub fn new(pool: SqlitePool) -> Self { - Self::new_with_config(pool, Config::default()) + Self { + pool, + job_type: PhantomData, + controller: Controller::new(), + config: Config::new(type_name::()), + codec: PhantomData, + } } /// Create a new instance with a custom config @@ -102,7 +108,7 @@ impl SqliteStorage { job_type: PhantomData, controller: Controller::new(), config, - codec: Arc::new(Box::new(JsonCodec)), + codec: PhantomData, } } /// Keeps a storage notified that the worker is still alive manually @@ -111,8 +117,7 @@ impl SqliteStorage { worker_id: &WorkerId, last_seen: i64, ) -> 
Result<(), sqlx::Error> { - let pool = self.pool.clone(); - let worker_type = T::NAME; + let worker_type = self.config.namespace.clone(); let storage_name = std::any::type_name::(); let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen) VALUES ($1, $2, $3, $4, $5) @@ -124,7 +129,7 @@ impl SqliteStorage { .bind(storage_name) .bind(std::any::type_name::()) .bind(last_seen) - .execute(&pool) + .execute(&self.pool) .await?; Ok(()) } @@ -133,42 +138,59 @@ impl SqliteStorage { pub fn pool(&self) -> &Pool { &self.pool } + + /// Get the config used by the storage + pub fn get_config(&self) -> &Config { + &self.config + } } -async fn fetch_next( - pool: Pool, +impl SqliteStorage { + /// Expose the code used + pub fn codec(&self) -> &PhantomData { + &self.codec + } +} + +async fn fetch_next( + pool: &Pool, worker_id: &WorkerId, id: String, + config: &Config, ) -> Result>, sqlx::Error> { let now: i64 = Utc::now().timestamp(); - let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL; Select * from Jobs where id = ?1 AND lock_by = ?2 AND job_type = ?4"; + let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3, attempts = attempts + 1 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL; Select * from Jobs where id = ?1 AND lock_by = ?2 AND job_type = ?4"; let job: Option> = sqlx::query_as(update_query) .bind(id.to_string()) .bind(worker_id.to_string()) .bind(now) - .bind(T::NAME) - .fetch_optional(&pool) + .bind(config.namespace.clone()) + .fetch_optional(pool) .await?; Ok(job) } -impl SqliteStorage { +impl SqliteStorage +where + T: DeserializeOwned + Send + Unpin, + C: Codec, +{ fn stream_jobs( &self, worker_id: &WorkerId, interval: Duration, buffer_size: usize, - ) -> impl Stream>, sqlx::Error>> { + ) -> impl Stream>, sqlx::Error>> { let pool = self.pool.clone(); let worker_id = worker_id.clone(); 
- let codec = self.codec.clone(); + let config = self.config.clone(); + let namespace = Namespace(self.config.namespace.clone()); try_stream! { loop { - apalis_core::sleep(interval).await; let tx = pool.clone(); let mut tx = tx.acquire().await?; - let job_type = T::NAME; + let job_type = &config.namespace; let fetch_query = "SELECT id FROM Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 LIMIT ?3"; let now: i64 = Utc::now().timestamp(); @@ -179,121 +201,119 @@ impl SqliteStorage { .fetch_all(&mut *tx) .await?; for id in ids { - let res = fetch_next::(pool.clone(), &worker_id, id.0).await?; + let res = fetch_next(&pool, &worker_id, id.0, &config).await?; yield match res { - None => None::>, - Some(c) => Some( - SqlRequest { - context: c.context, - req: codec.decode(&c.req).map_err(|e| { - sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - ), + None => None::>, + Some(job) => { + let (req, parts) = job.req.take_parts(); + let args = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + let mut req = Request::new_with_parts(args, parts); + req.parts.namespace = Some(namespace.clone()); + Some(req) + } } - - .map(Into::into); - } + }; + apalis_core::sleep(interval).await; } } } } -impl Storage for SqliteStorage +impl Storage for SqliteStorage where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, + C: Codec + Send, { type Job = T; type Error = sqlx::Error; - type Identifier = TaskId; - - async fn push(&mut self, job: Self::Job) -> Result { - let id = TaskId::new(); - let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, 25, strftime('%s','now'), NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); + type Context = SqlContext; - let job = self - .codec - .encode(&job) + async fn push_request( + &mut self, + job: Request, 
+ ) -> Result, Self::Error> { + let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL)"; + let (task, parts) = job.take_parts(); + let raw = C::encode(&task) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) - .bind(job) - .bind(id.to_string()) + .bind(raw) + .bind(parts.task_id.to_string()) .bind(job_type.to_string()) - .execute(&pool) + .bind(&parts.context.max_attempts()) + .execute(&self.pool) .await?; - Ok(id) + Ok(parts) } - async fn schedule(&mut self, job: Self::Job, on: i64) -> Result { + async fn schedule_request( + &mut self, + req: Request, + on: i64, + ) -> Result, Self::Error> { let query = - "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, 25, ?4, NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let id = TaskId::new(); - let job = self - .codec - .encode(&job) + "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL)"; + let id = &req.parts.task_id; + let job = C::encode(&req.args) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let job_type = T::NAME; + let job_type = self.config.namespace.clone(); sqlx::query(query) .bind(job) .bind(id.to_string()) .bind(job_type) + .bind(&req.parts.context.max_attempts()) .bind(on) - .execute(&pool) + .execute(&self.pool) .await?; - Ok(id) + Ok(req.parts) } async fn fetch_by_id( - &self, + &mut self, job_id: &TaskId, - ) -> Result>, Self::Error> { - let pool = self.pool.clone(); + ) -> Result>, Self::Error> { let fetch_query = "SELECT * FROM Jobs WHERE id = ?1"; let res: Option> = sqlx::query_as(fetch_query) .bind(job_id.to_string()) - .fetch_optional(&pool) + .fetch_optional(&self.pool) .await?; match res { None => Ok(None), - Some(c) => Ok(Some( - SqlRequest { - context: c.context, - req: self.codec.decode(&c.req).map_err(|e| { - 
sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)) - })?, - } - .into(), - )), + Some(job) => Ok(Some({ + let (req, parts) = job.req.take_parts(); + let args = C::decode(req) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; + + let mut req: Request = Request::new_with_parts(args, parts); + req.parts.namespace = Some(Namespace(self.config.namespace.clone())); + req + })), } } - async fn len(&self) -> Result { - let pool = self.pool.clone(); - + async fn len(&mut self) -> Result { let query = "Select Count(*) as count from Jobs where status='Pending'"; - let record = sqlx::query(query).fetch_one(&pool).await?; + let record = sqlx::query(query).fetch_one(&self.pool).await?; record.try_get("count") } - async fn reschedule(&mut self, job: Request, wait: Duration) -> Result<(), Self::Error> { - let pool = self.pool.clone(); - let task_id = job.get::().ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing TaskId", - )))?; + async fn reschedule( + &mut self, + job: Request, + wait: Duration, + ) -> Result<(), Self::Error> { + let task_id = job.parts.task_id; let wait: i64 = wait .as_secs() .try_into() .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1"; let now: i64 = Utc::now().timestamp(); @@ -307,22 +327,16 @@ where Ok(()) } - async fn update(&self, job: Request) -> Result<(), Self::Error> { - let pool = self.pool.clone(); - let ctx = job - .get::() - .ok_or(sqlx::Error::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Missing SqlContext", - )))?; + async fn update(&mut self, job: Request) -> Result<(), Self::Error> { + let ctx = job.parts.context; let status = ctx.status().to_string(); - let attempts = ctx.attempts(); + let attempts = job.parts.attempt; let done_at = 
*ctx.done_at(); let lock_by = ctx.lock_by().clone(); let lock_at = *ctx.lock_at(); let last_error = ctx.last_error().clone(); - let job_id = ctx.id(); - let mut tx = pool.acquire().await?; + let job_id = job.parts.task_id; + let mut tx = self.pool.acquire().await?; let query = "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6 WHERE id = ?7"; sqlx::query(query) @@ -343,29 +357,26 @@ where Ok(()) } - async fn is_empty(&self) -> Result { + async fn is_empty(&mut self) -> Result { self.len().map_ok(|c| c == 0).await } - async fn vacuum(&self) -> Result { - let pool = self.pool.clone(); + async fn vacuum(&mut self) -> Result { let query = "Delete from Jobs where status='Done'"; - let record = sqlx::query(query).execute(&pool).await?; + let record = sqlx::query(query).execute(&self.pool).await?; Ok(record.rows_affected().try_into().unwrap_or_default()) } } impl SqliteStorage { /// Puts the job instantly back into the queue - /// Another [Worker] may consume + /// Another Worker may consume pub async fn retry( &mut self, worker_id: &WorkerId, job_id: &TaskId, ) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - - let mut tx = pool.acquire().await?; + let mut tx = self.pool.acquire().await?; let query = "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2"; sqlx::query(query) @@ -378,9 +389,7 @@ impl SqliteStorage { /// Kill a job pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> { - let pool = self.pool.clone(); - - let mut tx = pool.begin().await?; + let mut tx = self.pool.begin().await?; let query = "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2"; sqlx::query(query) @@ -392,12 +401,9 @@ impl SqliteStorage { Ok(()) } - /// Add jobs that failed back to the queue if there are still remaining attempts - pub async fn reenqueue_failed(&self) -> Result<(), sqlx::Error> - 
where - T: Job, - { - let job_type = T::NAME; + /// Add jobs that failed back to the queue if there are still remaining attemps + pub async fn reenqueue_failed(&mut self) -> Result<(), sqlx::Error> { + let job_type = self.config.namespace.clone(); let mut tx = self.pool.acquire().await?; let query = r#"Update Jobs SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL @@ -419,11 +425,12 @@ impl SqliteStorage { } /// Add jobs that workers have disappeared to the queue - pub async fn reenqueue_orphaned(&self, timeout: i64) -> Result<(), sqlx::Error> - where - T: Job, - { - let job_type = T::NAME; + pub async fn reenqueue_orphaned( + &self, + count: i32, + dead_since: DateTime, + ) -> Result<(), sqlx::Error> { + let job_type = self.config.namespace.clone(); let mut tx = self.pool.acquire().await?; let query = r#"Update Jobs SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ="Job was abandoned" @@ -433,78 +440,198 @@ impl SqliteStorage { AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#; sqlx::query(query) - .bind(timeout) + .bind(dead_since.timestamp()) .bind(job_type) - .bind::(self.config.buffer_size.try_into().unwrap()) + .bind(count) .execute(&mut *tx) .await?; Ok(()) } } -impl Backend> - for SqliteStorage -{ - type Stream = BackendStream>>; - type Layer = AckLayer, T>; +/// Errors that can occur while polling an SQLite database. +#[derive(thiserror::Error, Debug)] +pub enum SqlitePollError { + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), - fn common_layer(&self, worker_id: WorkerId) -> Self::Layer { - AckLayer::new(self.clone(), worker_id) - } + /// Error during re-enqueuing orphaned tasks. 
+ #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), +} - fn poll(mut self, worker: WorkerId) -> Poller { +impl + Backend, Res> for SqliteStorage +{ + type Stream = BackendStream>>; + type Layer = AckLayer, T, SqlContext, Res>; + + fn poll(mut self, worker: &Worker) -> Poller { + let layer = AckLayer::new(self.clone()); let config = self.config.clone(); let controller = self.controller.clone(); let stream = self - .stream_jobs(&worker, config.poll_interval, config.buffer_size) - .map_err(|e| Error::SourceError(Box::new(e))); + .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) + .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); + let requeue_storage = self.clone(); + let w = worker.clone(); let heartbeat = async move { loop { let now: i64 = Utc::now().timestamp(); - self.keep_alive_at::(&worker, now) - .await - .unwrap(); + if let Err(e) = self.keep_alive_at::(w.id(), now).await { + w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e)))); + } apalis_core::sleep(Duration::from_secs(30)).await; } } .boxed(); - Poller::new(stream, heartbeat) + let w = worker.clone(); + let reenqueue_beat = async move { + loop { + let dead_since = Utc::now() + - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); + if let Err(e) = requeue_storage + .reenqueue_orphaned( + config + .buffer_size + .try_into() + .expect("could not convert usize to i32"), + dead_since, + ) + .await + { + w.emit(Event::Error(Box::new( + SqlitePollError::ReenqueueOrphanedError(e), + ))); + } + apalis_core::sleep(config.poll_interval).await; + } + }; + Poller::new_with_layer( + stream, + async { + futures::join!(heartbeat, reenqueue_beat); + }, + layer, + ) } } -impl Ack for SqliteStorage { - type Acknowledger = TaskId; - type Error = sqlx::Error; - async fn ack( - &self, - worker_id: &WorkerId, - task_id: &Self::Acknowledger, - ) -> 
Result<(), sqlx::Error> { +impl Ack for SqliteStorage { + type Context = SqlContext; + type AckError = sqlx::Error; + async fn ack(&mut self, ctx: &Self::Context, res: &Response) -> Result<(), sqlx::Error> { let pool = self.pool.clone(); let query = - "UPDATE Jobs SET status = 'Done', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2"; + "UPDATE Jobs SET status = ?4, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2"; + let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string())) + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; sqlx::query(query) - .bind(task_id.to_string()) - .bind(worker_id.to_string()) + .bind(res.task_id.to_string()) + .bind( + ctx.lock_by() + .as_ref() + .expect("Task is not locked") + .to_string(), + ) + .bind(result) + .bind(calculate_status(&res.inner).to_string()) .execute(&pool) .await?; Ok(()) } } +impl BackendExpose + for SqliteStorage> +{ + type Request = Request>; + type Error = SqlError; + async fn stats(&self) -> Result { + let fetch_query = "SELECT + COUNT(1) FILTER (WHERE status = 'Pending') AS pending, + COUNT(1) FILTER (WHERE status = 'Running') AS running, + COUNT(1) FILTER (WHERE status = 'Done') AS done, + COUNT(1) FILTER (WHERE status = 'Failed') AS failed, + COUNT(1) FILTER (WHERE status = 'Killed') AS killed + FROM Jobs WHERE job_type = ?"; + + let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query) + .bind(self.get_config().namespace()) + .fetch_one(self.pool()) + .await?; + + Ok(Stat { + pending: res.0.try_into()?, + running: res.1.try_into()?, + dead: res.4.try_into()?, + failed: res.3.try_into()?, + success: res.2.try_into()?, + }) + } + + async fn list_jobs( + &self, + status: &State, + page: i32, + ) -> Result, Self::Error> { + let status = status.to_string(); + let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? 
ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?"; + let res: Vec> = sqlx::query_as(fetch_query) + .bind(status) + .bind(self.get_config().namespace()) + .bind(((page - 1) * 10).to_string()) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|j| { + let (req, ctx) = j.req.take_parts(); + let req = JsonCodec::::decode(req).unwrap(); + Request::new_with_ctx(req, ctx) + }) + .collect()) + } + + async fn list_workers(&self) -> Result>, Self::Error> { + let fetch_query = + "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?"; + let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query) + .bind(self.get_config().namespace()) + .bind(0) + .fetch_all(self.pool()) + .await?; + Ok(res + .into_iter() + .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::(w.1))) + .collect()) + } +} + #[cfg(test)] mod tests { - use crate::context::State; + use crate::sql_storage_tests; use super::*; + use apalis_core::request::State; + use apalis_core::test_utils::DummyService; + use chrono::Utc; + use email_service::example_good_email; use email_service::Email; use futures::StreamExt; - use sqlx::types::chrono::Utc; + + use apalis_core::generic_storage_test; + use apalis_core::test_utils::apalis_test_service_fn; + use apalis_core::test_utils::TestWrapper; + + generic_storage_test!(setup); + sql_storage_tests!(setup::, SqliteStorage, Email); /// migrate DB and return a storage instance. - async fn setup() -> SqliteStorage { + async fn setup() -> SqliteStorage { // Because connections cannot be shared across async runtime // (different runtimes are created for each test), // we don't share the storage and tests must be run sequentially. 
@@ -512,7 +639,7 @@ mod tests { SqliteStorage::setup(&pool) .await .expect("failed to migrate DB"); - let storage = SqliteStorage::::new(pool); + let storage = SqliteStorage::::new(pool); storage } @@ -532,20 +659,10 @@ mod tests { assert_eq!(len, 1); } - struct DummyService {} - - fn example_email() -> Email { - Email { - subject: "Test Subject".to_string(), - to: "example@postgres".to_string(), - text: "Some Text".to_string(), - } - } - async fn consume_one( storage: &mut SqliteStorage, worker_id: &WorkerId, - ) -> Request { + ) -> Request { let mut stream = storage .stream_jobs(worker_id, std::time::Duration::from_secs(10), 1) .boxed(); @@ -575,7 +692,10 @@ mod tests { storage.push(email).await.expect("failed to push a job"); } - async fn get_job(storage: &mut SqliteStorage, job_id: &TaskId) -> Request { + async fn get_job( + storage: &mut SqliteStorage, + job_id: &TaskId, + ) -> Request { storage .fetch_by_id(job_id) .await @@ -586,12 +706,14 @@ mod tests { #[tokio::test] async fn test_consume_last_pushed_job() { let mut storage = setup().await; - push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + push_email(&mut storage, example_good_email()).await; + let len = storage.len().await.expect("Could not fetch the jobs count"); + assert_eq!(len, 1); + let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); assert!(ctx.lock_at().is_some()); @@ -600,21 +722,23 @@ mod tests { #[tokio::test] async fn test_acknowledge_job() { let mut storage = setup().await; - push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + push_email(&mut storage, example_good_email()).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); - + let job_id = 
&job.parts.task_id; + let ctx = &job.parts.context; + let res = 1usize; storage - .ack(&worker_id, job_id) + .ack( + ctx, + &Response::success(res, job_id.clone(), job.parts.attempt.clone()), + ) .await .expect("failed to acknowledge the job"); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Done); assert!(ctx.done_at().is_some()); } @@ -623,13 +747,12 @@ mod tests { async fn test_kill_job() { let mut storage = setup().await; - push_email(&mut storage, example_email()).await; + push_email(&mut storage, example_good_email()).await; let worker_id = register_worker(&mut storage).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); - let job_id = ctx.id(); + let job_id = &job.parts.task_id; storage .kill(&worker_id, job_id) @@ -637,7 +760,7 @@ mod tests { .expect("failed to kill job"); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Killed); assert!(ctx.done_at().is_some()); } @@ -646,50 +769,52 @@ mod tests { async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() { let mut storage = setup().await; - push_email(&mut storage, example_email()).await; + push_email(&mut storage, example_good_email()).await; let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); + let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); let worker_id = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let job_id = &job.parts.task_id; storage - .reenqueue_orphaned(six_minutes_ago.timestamp()) + .reenqueue_orphaned(1, five_minutes_ago) .await .expect("failed to heartbeat"); - - let job_id = ctx.id(); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); - // TODO: rework these assertions - // 
assert_eq!(*ctx.status(), State::Pending); - // assert!(ctx.done_at().is_none()); - // assert!(ctx.lock_by().is_none()); - // assert!(ctx.lock_at().is_none()); - // assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_string())); + let ctx = &job.parts.context; + assert_eq!(*ctx.status(), State::Pending); + assert!(ctx.done_at().is_none()); + assert!(ctx.lock_by().is_none()); + assert!(ctx.lock_at().is_none()); + assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned())); + assert_eq!(job.parts.attempt.current(), 1); } #[tokio::test] async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() { let mut storage = setup().await; - push_email(&mut storage, example_email()).await; + push_email(&mut storage, example_good_email()).await; + let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60); let worker_id = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; let job = consume_one(&mut storage, &worker_id).await; - let ctx = job.get::().unwrap(); + let job_id = &job.parts.task_id; storage - .reenqueue_orphaned(four_minutes_ago.timestamp()) + .reenqueue_orphaned(1, six_minutes_ago) .await .expect("failed to heartbeat"); - let job_id = ctx.id(); let job = get_job(&mut storage, job_id).await; - let ctx = job.get::().unwrap(); + let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert!(ctx.lock_at().is_some()); + assert_eq!(*ctx.last_error(), None); + assert_eq!(job.parts.attempt.current(), 1); } } diff --git a/src/layers/catch_panic/mod.rs b/src/layers/catch_panic/mod.rs new file mode 100644 index 0000000..ba869d1 --- /dev/null +++ b/src/layers/catch_panic/mod.rs @@ -0,0 +1,207 @@ +use std::any::Any; +use std::fmt; +use std::future::Future; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use apalis_core::error::Error; 
+use apalis_core::request::Request; +use tower::Layer; +use tower::Service; + +/// Apalis Layer that catches panics in the service. +#[derive(Clone, Debug)] +pub struct CatchPanicLayer { + on_panic: F, +} + +impl CatchPanicLayer) -> Error> { + /// Creates a new `CatchPanicLayer` with a default panic handler. + pub fn new() -> Self { + CatchPanicLayer { + on_panic: default_handler, + } + } +} + +impl Default for CatchPanicLayer) -> Error> { + fn default() -> Self { + Self::new() + } +} + +impl CatchPanicLayer +where + F: FnMut(Box) -> Error + Clone, +{ + /// Creates a new `CatchPanicLayer` with a custom panic handler. + pub fn with_panic_handler(on_panic: F) -> Self { + CatchPanicLayer { on_panic } + } +} + +impl Layer for CatchPanicLayer +where + F: FnMut(Box) -> Error + Clone, +{ + type Service = CatchPanicService; + + fn layer(&self, service: S) -> Self::Service { + CatchPanicService { + service, + on_panic: self.on_panic.clone(), + } + } +} + +/// Apalis Service that catches panics. +#[derive(Clone, Debug)] +pub struct CatchPanicService { + service: S, + on_panic: F, +} + +impl Service> for CatchPanicService +where + S: Service, Response = Res, Error = Error>, + F: FnMut(Box) -> Error + Clone, +{ + type Response = S::Response; + type Error = S::Error; + type Future = CatchPanicFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + CatchPanicFuture { + future: self.service.call(request), + on_panic: self.on_panic.clone(), + } + } +} + +pin_project_lite::pin_project! 
{ + /// A wrapper that catches panics during execution + pub struct CatchPanicFuture { + #[pin] + future: Fut, + on_panic: F, + } +} + +/// An error generated from a panic +#[derive(Debug, Clone)] +pub struct PanicError(pub String); + +impl std::error::Error for PanicError {} + +impl fmt::Display for PanicError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PanicError: {}", self.0) + } +} + +impl Future for CatchPanicFuture +where + Fut: Future>, + F: FnMut(Box) -> Error, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + + match catch_unwind(AssertUnwindSafe(|| this.future.poll(cx))) { + Ok(res) => res, + Err(e) => Poll::Ready(Err((this.on_panic)(e))), + } + } +} + +fn default_handler(e: Box) -> Error { + let panic_info = if let Some(s) = e.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = e.downcast_ref::() { + s.clone() + } else { + "Unknown panic".to_string() + }; + // apalis assumes service functions are pure + // therefore a panic should ideally abort + Error::Abort(Arc::new(Box::new(PanicError(panic_info)))) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::task::{Context, Poll}; + use tower::Service; + + #[derive(Clone, Debug)] + struct TestJob; + + #[derive(Clone)] + struct TestService; + + impl Service> for TestService { + type Response = usize; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + Box::pin(async { Ok(42) }) + } + } + + #[tokio::test] + async fn test_catch_panic_layer() { + let layer = CatchPanicLayer::new(); + let mut service = layer.layer(TestService); + + let request = Request::new(TestJob); + let response = service.call(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_catch_panic_layer_panics() { + struct 
PanicService; + + impl Service> for PanicService { + type Response = usize; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + Box::pin(async { None.unwrap() }) + } + } + + let layer = CatchPanicLayer::new(); + let mut service = layer.layer(PanicService); + + let request = Request::new(TestJob); + let response = service.call(request).await; + + assert!(response.is_err()); + + assert_eq!( + response.unwrap_err().to_string(), + *"AbortError: PanicError: called `Option::unwrap()` on a `None` value" + ); + } +} diff --git a/src/layers/mod.rs b/src/layers/mod.rs index e7b5e99..47bfbeb 100644 --- a/src/layers/mod.rs +++ b/src/layers/mod.rs @@ -18,10 +18,301 @@ pub mod tracing; #[cfg(feature = "limit")] #[cfg_attr(docsrs, doc(cfg(feature = "limit")))] pub mod limit { + pub use tower::limit::ConcurrencyLimitLayer; + pub use tower::limit::GlobalConcurrencyLimitLayer; pub use tower::limit::RateLimitLayer; } +use apalis_core::{builder::WorkerBuilder, layers::Identity}; +#[cfg(feature = "catch-panic")] +use catch_panic::CatchPanicLayer; +use tower::layer::util::Stack; /// Timeout middleware for apalis #[cfg(feature = "timeout")] #[cfg_attr(docsrs, doc(cfg(feature = "timeout")))] pub use tower::timeout::TimeoutLayer; + +/// catch panic middleware for apalis +#[cfg(feature = "catch-panic")] +#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] +pub mod catch_panic; + +pub use apalis_core::error::ErrorHandlingLayer; + +/// A trait that extends `WorkerBuilder` with additional middleware methods +/// derived from `tower::ServiceBuilder`. +pub trait WorkerBuilderExt { + /// Optionally adds a new layer `T` into the [`WorkerBuilder`]. + fn option_layer( + self, + layer: Option, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Adds a [`Layer`] built from a function that accepts a service and returns another service. 
+ fn layer_fn( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Limits the max number of in-flight requests. + #[cfg(feature = "limit")] + fn concurrency( + self, + max: usize, + ) -> WorkerBuilder, Serv>; + + /// Limits requests to at most `num` per the given duration. + #[cfg(feature = "limit")] + fn rate_limit( + self, + num: u64, + per: std::time::Duration, + ) -> WorkerBuilder, Serv>; + + /// Retries failed requests according to the given retry policy. + #[cfg(feature = "retry")] + fn retry

( + self, + policy: P, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Fails requests that take longer than `timeout`. + #[cfg(feature = "timeout")] + fn timeout( + self, + timeout: std::time::Duration, + ) -> WorkerBuilder, Serv>; + + /// Conditionally rejects requests based on `predicate`. + #[cfg(feature = "filter")] + fn filter

( + self, + predicate: P, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Conditionally rejects requests based on an asynchronous `predicate`. + #[cfg(feature = "filter")] + fn filter_async

( + self, + predicate: P, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Maps one request type to another. + fn map_request( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> + where + F: FnMut(R1) -> R2 + Clone; + + /// Maps one response type to another. + fn map_response( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Maps one error type to another. + fn map_err( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Composes a function that transforms futures produced by the service. + fn map_future( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Applies an asynchronous function after the service, regardless of whether the future succeeds or fails. + fn then( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Executes a new future after this service's future resolves. + fn and_then( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Maps the service's result type to a different value, regardless of success or failure. 
+ fn map_result( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv>; + + /// Catch panics in execution and pipe them as errors + #[cfg(feature = "catch-panic")] + #[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] + #[allow(clippy::type_complexity)] + fn catch_panic( + self, + ) -> WorkerBuilder< + Req, + Ctx, + Source, + Stack< + CatchPanicLayer) -> apalis_core::error::Error>, + Middleware, + >, + Serv, + >; + /// Enable tracing via tracing crate + #[cfg(feature = "tracing")] + #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] + fn enable_tracing( + self, + ) -> WorkerBuilder, Serv>; +} + +impl WorkerBuilderExt + for WorkerBuilder +{ + fn option_layer( + self, + layer: Option, + ) -> WorkerBuilder, Middleware>, Serv> + { + self.chain(|sb| sb.option_layer(layer)) + } + + fn layer_fn( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.layer_fn(f)) + } + + #[cfg(feature = "limit")] + fn concurrency( + self, + max: usize, + ) -> WorkerBuilder, Serv> + { + self.chain(|sb| sb.concurrency_limit(max)) + } + + #[cfg(feature = "limit")] + fn rate_limit( + self, + num: u64, + per: std::time::Duration, + ) -> WorkerBuilder, Serv> { + self.chain(|sb| sb.rate_limit(num, per)) + } + + #[cfg(feature = "retry")] + fn retry

( + self, + policy: P, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.retry(policy)) + } + + #[cfg(feature = "timeout")] + fn timeout( + self, + timeout: std::time::Duration, + ) -> WorkerBuilder, Serv> { + self.chain(|sb| sb.timeout(timeout)) + } + + #[cfg(feature = "filter")] + fn filter

( + self, + predicate: P, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.filter(predicate)) + } + + #[cfg(feature = "filter")] + fn filter_async

( + self, + predicate: P, + ) -> WorkerBuilder, Middleware>, Serv> + { + self.chain(|sb| sb.filter_async(predicate)) + } + + fn map_request( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> + where + F: FnMut(R1) -> R2 + Clone, + { + self.chain(|sb| sb.map_request(f)) + } + + fn map_response( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> + { + self.chain(|sb| sb.map_response(f)) + } + + fn map_err( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.map_err(f)) + } + + fn map_future( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.map_future(f)) + } + + fn then( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.then(f)) + } + + fn and_then( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.and_then(f)) + } + + fn map_result( + self, + f: F, + ) -> WorkerBuilder, Middleware>, Serv> { + self.chain(|sb| sb.map_result(f)) + } + + /// Catch panics in execution and pipe them as errors + #[cfg(feature = "catch-panic")] + #[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] + fn catch_panic( + self, + ) -> WorkerBuilder< + Req, + Ctx, + (), + Stack< + CatchPanicLayer) -> apalis_core::error::Error>, + Middleware, + >, + Serv, + > { + self.chain(|svc| svc.layer(CatchPanicLayer::new())) + } + + /// Enable tracing via tracing crate + #[cfg(feature = "tracing")] + #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] + fn enable_tracing( + self, + ) -> WorkerBuilder, Serv> { + use tracing::TraceLayer; + + self.chain(|svc| svc.layer(TraceLayer::new())) + } +} diff --git a/src/layers/prometheus/mod.rs b/src/layers/prometheus/mod.rs index bb39cd2..a0d5008 100644 --- a/src/layers/prometheus/mod.rs +++ b/src/layers/prometheus/mod.rs @@ -4,14 +4,15 @@ use std::{ time::Instant, }; -use apalis_core::{error::Error, request::Request, storage::Job}; +use apalis_core::{error::Error, request::Request}; use futures::Future; use 
pin_project_lite::pin_project; use tower::{Layer, Service}; /// A layer to support prometheus metrics #[derive(Debug, Default)] -pub struct PrometheusLayer; +#[non_exhaustive] +pub struct PrometheusLayer {} impl Layer for PrometheusLayer { type Service = PrometheusService; @@ -27,30 +28,36 @@ pub struct PrometheusService { service: S, } -impl Service> for PrometheusService +impl Service> for PrometheusService where - S: Service, Response = Res, Error = Error, Future = F>, - F: Future> + 'static, - J: Job, + Svc: Service, Response = Res, Error = Error, Future = Fut>, + Fut: Future> + 'static, { - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; + type Response = Svc::Response; + type Error = Svc::Error; + type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, request: Request) -> Self::Future { + fn call(&mut self, request: Request) -> Self::Future { let start = Instant::now(); + let namespace = request + .parts + .namespace + .as_ref() + .map(|ns| ns.0.to_string()) + .unwrap_or(std::any::type_name::().to_string()); + let req = self.service.call(request); - let job_type = std::any::type_name::().to_string(); - let op = J::NAME; + let job_type = std::any::type_name::().to_string(); + ResponseFuture { inner: req, start, job_type, - operation: op.to_string(), + operation: namespace, } } } @@ -89,8 +96,10 @@ where ("namespace", this.job_type.to_string()), ("status", status), ]; - metrics::counter!("requests_total", &labels).increment(1); - metrics::histogram!("request_duration_seconds", &labels).record(latency); + let counter = metrics::counter!("requests_total", &labels); + counter.increment(1); + let hist = metrics::histogram!("request_duration_seconds", &labels); + hist.record(latency); Poll::Ready(response) } } diff --git a/src/layers/retry/mod.rs b/src/layers/retry/mod.rs index 5b5ae46..e616434 100644 --- a/src/layers/retry/mod.rs +++ 
b/src/layers/retry/mod.rs @@ -1,15 +1,13 @@ use futures::future; use tower::retry::Policy; +use apalis_core::{error::Error, request::Request}; /// Re-export from [`RetryLayer`] /// /// [`RetryLayer`]: tower::retry::RetryLayer pub use tower::retry::RetryLayer; -use apalis_core::task::attempt::Attempt; -use apalis_core::{error::Error, request::Request}; - -type Req = Request; +type Req = Request; type Err = Error; /// Retries a task instantly for `retries` @@ -31,36 +29,34 @@ impl RetryPolicy { } } -impl Policy, Res, Err> for RetryPolicy +impl Policy, Res, Err> for RetryPolicy where T: Clone, + Ctx: Clone, { type Future = future::Ready<()>; - fn retry(&mut self, req: &mut Req, result: &mut Result) -> Option { - let ctx = req.get::().cloned().unwrap_or_default(); + fn retry( + &mut self, + req: &mut Req, + result: &mut Result, + ) -> Option { + let attempt = &req.parts.attempt; match result { Ok(_) => { // Treat all `Response`s as success, // so don't retry... None } - Err(_) if (self.retries - ctx.current() > 0) => Some(future::ready(())), + Err(_) if self.retries == 0 => None, + Err(_) if (self.retries - attempt.current() > 0) => Some(future::ready(())), Err(_) => None, } } - fn clone_request(&mut self, req: &Req) -> Option> { - let mut req = req.clone(); - let value = req - .get::() - .cloned() - .map(|attempt| { - attempt.increment(); - attempt - }) - .unwrap_or_default(); - req.insert(value); + fn clone_request(&mut self, req: &Req) -> Option> { + let req = req.clone(); + req.parts.attempt.increment(); Some(req) } } diff --git a/src/layers/sentry/mod.rs b/src/layers/sentry/mod.rs index 9bcfcd7..7e6d50c 100644 --- a/src/layers/sentry/mod.rs +++ b/src/layers/sentry/mod.rs @@ -1,16 +1,13 @@ +use sentry_core::protocol; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; - -use sentry_core::protocol; use tower::Layer; use tower::Service; use apalis_core::error::Error; use apalis_core::request::Request; -use 
apalis_core::storage::Job; -use apalis_core::task::attempt::Attempt; use apalis_core::task::task_id::TaskId; /// Tower Layer that logs Job Details. @@ -126,34 +123,39 @@ where } } -impl Service> for SentryJobService +impl Service> for SentryJobService where - S: Service, Response = Res, Error = Error, Future = F>, - F: Future> + 'static, - J: Job, + Svc: Service, Response = Res, Error = Error, Future = Fut>, + Fut: Future> + 'static, { - type Response = S::Response; - type Error = S::Error; - type Future = SentryHttpFuture; + type Response = Svc::Response; + type Error = Svc::Error; + type Future = SentryHttpFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, request: Request) -> Self::Future { - let op = J::NAME; - let trx_ctx = sentry_core::TransactionContext::new(op, "apalis.job"); - let job_type = std::any::type_name::().to_string(); - let ctx = request.get::().cloned().unwrap_or_default(); - let task_id = request.get::().unwrap(); - let job_details = Task { + fn call(&mut self, request: Request) -> Self::Future { + let task_type = std::any::type_name::().to_string(); + let attempt = &request.parts.attempt; + let task_id = &request.parts.task_id; + let namespace = request + .parts + .namespace + .as_ref() + .map(|s| s.0.as_str()) + .unwrap_or(std::any::type_name::()); + let trx_ctx = sentry_core::TransactionContext::new(namespace, "apalis.task"); + + let task_details = Task { id: task_id.clone(), - current_attempt: ctx.current().try_into().unwrap(), - namespace: job_type, + current_attempt: attempt.current().try_into().unwrap(), + namespace: task_type, }; SentryHttpFuture { - on_first_poll: Some((job_details, trx_ctx)), + on_first_poll: Some((task_details, trx_ctx)), transaction: None, future: self.service.call(request), } diff --git a/src/layers/tracing/make_span.rs b/src/layers/tracing/make_span.rs index 58698ca..101d7ef 100644 --- a/src/layers/tracing/make_span.rs +++ 
b/src/layers/tracing/make_span.rs @@ -8,22 +8,22 @@ use super::DEFAULT_MESSAGE_LEVEL; /// /// [`Span`]: tracing::Span /// [`Trace`]: super::Trace -pub trait MakeSpan { +pub trait MakeSpan { /// Make a span from a request. - fn make_span(&mut self, request: &Request) -> Span; + fn make_span(&mut self, request: &Request) -> Span; } -impl MakeSpan for Span { - fn make_span(&mut self, _request: &Request) -> Span { +impl MakeSpan for Span { + fn make_span(&mut self, _request: &Request) -> Span { self.clone() } } -impl MakeSpan for F +impl MakeSpan for F where - F: FnMut(&Request) -> Span, + F: FnMut(&Request) -> Span, { - fn make_span(&mut self, request: &Request) -> Span { + fn make_span(&mut self, request: &Request) -> Span { self(request) } } @@ -62,18 +62,22 @@ impl Default for DefaultMakeSpan { } } -impl MakeSpan for DefaultMakeSpan { - fn make_span(&mut self, _req: &Request) -> Span { +impl MakeSpan for DefaultMakeSpan { + fn make_span(&mut self, req: &Request) -> Span { // This ugly macro is needed, unfortunately, because `tracing::span!` // required the level argument to be static. Meaning we can't just pass // `self.level`. + let task_id = req.parts.task_id.to_string(); + let attempt = req.parts.attempt.current(); let span = Span::current(); macro_rules! 
make_span { ($level:expr) => { tracing::span!( parent: span, $level, - "job", + "task", + task_id = task_id, + attempt = attempt ) }; } diff --git a/src/layers/tracing/mod.rs b/src/layers/tracing/mod.rs index 2ceb3e9..0f67509 100644 --- a/src/layers/tracing/mod.rs +++ b/src/layers/tracing/mod.rs @@ -3,7 +3,7 @@ mod on_failure; mod on_request; mod on_response; -use apalis_core::{error::Error, request::Request}; +use apalis_core::request::Request; use std::{ fmt::{self, Debug}, pin::Pin, @@ -289,26 +289,26 @@ impl } } -impl Service> +impl Service> for Trace where - S: Service, Response = Res, Error = Error, Future = F> + Unpin + Send + 'static, + S: Service, Response = Res, Future = F> + Unpin + Send + 'static, S::Error: fmt::Display + 'static, - MakeSpanT: MakeSpan, - OnRequestT: OnRequest, + MakeSpanT: MakeSpan, + OnRequestT: OnRequest, OnResponseT: OnResponse + Clone + 'static, - F: Future> + 'static, - OnFailureT: OnFailure + Clone + 'static, + F: Future> + 'static, + OnFailureT: OnFailure + Clone + 'static, { type Response = Res; - type Error = Error; + type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let span = self.make_span.make_span(&req); let start = Instant::now(); let job = { @@ -339,14 +339,14 @@ pin_project! 
{ } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - Fut: Future>, + Fut: Future>, OnResponseT: OnResponse, - OnFailureT: OnFailure, + OnFailureT: OnFailure, { - type Output = Result; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); diff --git a/src/layers/tracing/on_failure.rs b/src/layers/tracing/on_failure.rs index 20bc071..43a5392 100644 --- a/src/layers/tracing/on_failure.rs +++ b/src/layers/tracing/on_failure.rs @@ -1,8 +1,6 @@ -use apalis_core::error::Error; - use super::{LatencyUnit, DEFAULT_ERROR_LEVEL}; -use std::time::Duration; +use std::{fmt::Display, time::Duration}; use tracing::{Level, Span}; /// Trait used to tell [`Trace`] what to do when a request fails. @@ -11,7 +9,7 @@ use tracing::{Level, Span}; /// `on_failure` callback is called. /// /// [`Trace`]: super::Trace -pub trait OnFailure { +pub trait OnFailure { /// Do the thing. /// /// `latency` is the duration since the request was received. @@ -23,19 +21,19 @@ pub trait OnFailure { /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::layers::tracing::TraceLayer::make_span_with - fn on_failure(&mut self, error: &Error, latency: Duration, span: &Span); + fn on_failure(&mut self, error: &E, latency: Duration, span: &Span); } -impl OnFailure for () { +impl OnFailure for () { #[inline] - fn on_failure(&mut self, _: &Error, _: Duration, _: &Span) {} + fn on_failure(&mut self, _: &E, _: Duration, _: &Span) {} } -impl OnFailure for F +impl OnFailure for F where - F: FnMut(&Error, Duration, &Span), + F: FnMut(&E, Duration, &Span), { - fn on_failure(&mut self, error: &Error, latency: Duration, span: &Span) { + fn on_failure(&mut self, error: &E, latency: Duration, span: &Span) { self(error, latency, span) } } @@ -135,8 +133,8 @@ macro_rules! 
log_pattern_match { }; } -impl OnFailure for DefaultOnFailure { - fn on_failure(&mut self, error: &Error, latency: Duration, span: &Span) { +impl OnFailure for DefaultOnFailure { + fn on_failure(&mut self, error: &E, latency: Duration, span: &Span) { log_pattern_match!( self, span, diff --git a/src/layers/tracing/on_request.rs b/src/layers/tracing/on_request.rs index f0be6b3..bc3ac71 100644 --- a/src/layers/tracing/on_request.rs +++ b/src/layers/tracing/on_request.rs @@ -10,7 +10,7 @@ use tracing::Span; /// `on_request` callback is called. /// /// [`Trace`]: super::Trace -pub trait OnRequest { +pub trait OnRequest { /// Do the thing. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure @@ -20,19 +20,19 @@ pub trait OnRequest { /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::layers::tracing::TraceLayer::make_span_with - fn on_request(&mut self, request: &Request, span: &Span); + fn on_request(&mut self, request: &Request, span: &Span); } -impl OnRequest for () { +impl OnRequest for () { #[inline] - fn on_request(&mut self, _: &Request, _: &Span) {} + fn on_request(&mut self, _: &Request, _: &Span) {} } -impl OnRequest for F +impl OnRequest for F where - F: FnMut(&Request, &Span), + F: FnMut(&Request, &Span), { - fn on_request(&mut self, request: &Request, span: &Span) { + fn on_request(&mut self, request: &Request, span: &Span) { self(request, span) } } @@ -76,23 +76,23 @@ impl DefaultOnRequest { } } -impl OnRequest for DefaultOnRequest { - fn on_request(&mut self, _: &Request, _: &Span) { +impl OnRequest for DefaultOnRequest { + fn on_request(&mut self, _: &Request, _: &Span) { match self.level { Level::ERROR => { - tracing::event!(Level::ERROR, "job.start",); + tracing::event!(Level::ERROR, "task.start",); } Level::WARN => { - tracing::event!(Level::WARN, "job.start",); 
+ tracing::event!(Level::WARN, "task.start",); } Level::INFO => { - tracing::event!(Level::INFO, "job.start",); + tracing::event!(Level::INFO, "task.start",); } Level::DEBUG => { - tracing::event!(Level::DEBUG, "job.start",); + tracing::event!(Level::DEBUG, "task.start",); } Level::TRACE => { - tracing::event!(Level::TRACE, "job.start",); + tracing::event!(Level::TRACE, "task.start",); } } } diff --git a/src/layers/tracing/on_response.rs b/src/layers/tracing/on_response.rs index fa7b67a..8889332 100644 --- a/src/layers/tracing/on_response.rs +++ b/src/layers/tracing/on_response.rs @@ -106,7 +106,7 @@ macro_rules! log_pattern_match { Level::$level, done_in = format_args!("{}s", $done_in.as_secs_f64()), result = format_args!("{:?}", $res), - "job.done" + "task.done" ); } (Level::$level, LatencyUnit::Millis) => { @@ -114,7 +114,7 @@ macro_rules! log_pattern_match { Level::$level, done_in = format_args!("{}ms", $done_in.as_millis()), result = format_args!("{:?}", $res), - "job.done" + "task.done" ); } (Level::$level, LatencyUnit::Micros) => { @@ -122,7 +122,7 @@ macro_rules! log_pattern_match { Level::$level, done_in = format_args!("{}μs", $done_in.as_micros()), result = format_args!("{:?}", $res), - "job.done" + "task.done" ); } (Level::$level, LatencyUnit::Nanos) => { @@ -130,7 +130,7 @@ macro_rules! log_pattern_match { Level::$level, done_in = format_args!("{}ns", $done_in.as_nanos()), result = format_args!("{:?}", $res), - "job.done" + "task.done" ); } diff --git a/src/lib.rs b/src/lib.rs index 268653d..385e3b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,17 +19,13 @@ //! ```rust, no_run //! use apalis::prelude::*; //! use serde::{Deserialize, Serialize}; -//! use apalis::redis::RedisStorage; +//! use apalis_redis::{RedisStorage, Config}; //! //! #[derive(Debug, Deserialize, Serialize)] //! struct Email { //! to: String, //! } //! -//! impl Job for Email { -//! const NAME: &'static str = "apalis::Email"; -//! } -//! //!
async fn send_email(job: Email, data: Data) -> Result<(), Error> { //! Ok(()) //! } @@ -37,13 +33,14 @@ //! #[tokio::main] //! async fn main() { //! let redis = std::env::var("REDIS_URL").expect("Missing REDIS_URL env variable"); -//! let conn = apalis::redis::connect(redis).await.unwrap(); +//! let conn = apalis_redis::connect(redis).await.unwrap(); //! let storage = RedisStorage::new(conn); -//! Monitor::::new() -//! .register_with_count(2, { +//! Monitor::new() +//! .register({ //! WorkerBuilder::new(&format!("quick-sand")) +//! .concurrency(2) //! .data(0usize) -//! .source(storage.clone()) +//! .backend(storage.clone()) //! .build_fn(send_email) //! }) //! .run() @@ -66,95 +63,35 @@ //! [`tower-http`]: https://crates.io/crates/tower-http //! [`Layer`]: https://docs.rs/tower/latest/tower/trait.Layer.html //! [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html - -/// Include the default Redis storage -#[cfg(feature = "redis")] -#[cfg_attr(docsrs, doc(cfg(feature = "redis")))] -pub mod redis { - pub use apalis_redis::*; -} - -/// Include the default Sqlite storage -#[cfg(feature = "sqlite")] -#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] -pub mod sqlite { - pub use apalis_sql::sqlite::*; -} - -/// Include the default Postgres storage -#[cfg(feature = "postgres")] -#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] -pub mod postgres { - pub use apalis_sql::postgres::*; -} - -/// Include the default MySQL storage -#[cfg(feature = "mysql")] -#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] -pub mod mysql { - pub use apalis_sql::mysql::*; -} - -/// Include Cron utilities -#[cfg(feature = "cron")] -#[cfg_attr(docsrs, doc(cfg(feature = "cron")))] -pub mod cron { - pub use apalis_cron::*; -} - /// apalis fully supports middleware via [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) pub mod layers; -/// Utilities for working with apalis -pub mod utils { - /// Executor for [`tokio`] - #[cfg(feature = "tokio-comp")] - 
#[derive(Clone, Debug, Default)] - pub struct TokioExecutor; - - #[cfg(feature = "tokio-comp")] - impl apalis_core::executor::Executor for TokioExecutor { - fn spawn(&self, future: impl std::future::Future + Send + 'static) { - tokio::spawn(future); - } - } - - /// Executor for [`async_std`] - #[cfg(feature = "async-std-comp")] - #[derive(Clone, Debug, Default)] - pub struct AsyncStdExecutor; - - #[cfg(feature = "async-std-comp")] - impl apalis_core::executor::Executor for AsyncStdExecutor { - fn spawn(&self, future: impl std::future::Future + Send + 'static) { - async_std::task::spawn(future); - } - } -} - /// Common imports pub mod prelude { - #[cfg(feature = "tokio-comp")] - pub use crate::utils::TokioExecutor; + pub use crate::layers::WorkerBuilderExt; pub use apalis_core::{ + backend::Backend, + backend::BackendExpose, + backend::Stat, + backend::WorkerState, builder::{WorkerBuilder, WorkerFactory, WorkerFactoryFn}, + codec::Codec, data::Extensions, error::{BoxDynError, Error}, - executor::Executor, layers::extensions::{AddExtension, Data}, memory::{MemoryStorage, MemoryWrapper}, - monitor::{Monitor, MonitorContext}, - mq::{Message, MessageQueue}, + monitor::Monitor, + mq::MessageQueue, notify::Notify, poller::stream::BackendStream, - poller::{controller::Controller, FetchNext, Poller}, + poller::{controller::Controller, Poller}, + request::State, request::{Request, RequestStream}, response::IntoResponse, - service_fn::{service_fn, FromData, ServiceFn}, - storage::{Job, Storage, StorageStream}, + service_fn::{service_fn, FromRequest, ServiceFn}, + storage::Storage, task::attempt::Attempt, task::task_id::TaskId, worker::{Context, Event, Ready, Worker, WorkerError, WorkerId}, - Backend, Codec, }; }