diff --git a/examples/tokio-redis/Cargo.toml b/examples/tokio-redis/Cargo.toml new file mode 100644 index 0000000000..fb276849e8 --- /dev/null +++ b/examples/tokio-redis/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-tokio-redis" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +bb8 = "0.7.1" +bb8-redis = "0.14.0" +redis = "0.24.0" +tokio = { version = "1.0", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tokio-redis/src/main.rs b/examples/tokio-redis/src/main.rs new file mode 100644 index 0000000000..f0109f2127 --- /dev/null +++ b/examples/tokio-redis/src/main.rs @@ -0,0 +1,106 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-tokio-redis +//! ``` + +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, + routing::get, + Router, +}; +use bb8::{Pool, PooledConnection}; +use bb8_redis::RedisConnectionManager; +use redis::AsyncCommands; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use bb8_redis::bb8; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_tokio_redis=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::debug!("connecting to redis"); + let manager = RedisConnectionManager::new("redis://localhost").unwrap(); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + + { + // ping the database before starting + let mut conn = pool.get().await.unwrap(); + conn.set::<&str, &str, ()>("foo", "bar").await.unwrap(); + let result: String = conn.get("foo").await.unwrap(); + assert_eq!(result, "bar"); + } + tracing::debug!("successfully connected to redis and pinged it"); + + // build our application with some routes + let app = Router::new() + .route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ) + .with_state(pool); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +type ConnectionPool = Pool; + +async fn using_connection_pool_extractor( + State(pool): State, +) -> Result { + let mut conn = pool.get().await.map_err(internal_error)?; + let result: String = conn.get("foo").await.map_err(internal_error)?; + Ok(result) +} + +// we can also write a custom extractor that grabs a connection from the pool +// which setup is appropriate depends on your application +struct DatabaseConnection(PooledConnection<'static, RedisConnectionManager>); + +#[async_trait] +impl FromRequestParts for DatabaseConnection +where + ConnectionPool: FromRef, + S: Send + Sync, +{ + type Rejection = (StatusCode, String); + + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = ConnectionPool::from_ref(state); + + let conn = pool.get_owned().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +async fn using_connection_extractor( + DatabaseConnection(mut conn): DatabaseConnection, +) -> Result { + let result: String = conn.get("foo").await.map_err(internal_error)?; + + Ok(result) +} + +/// Utility function for mapping any error into a `500 Internal Server Error` +/// response. +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) +}