diff --git a/Cargo.lock b/Cargo.lock index 13d1484..3b59a98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,6 +387,7 @@ dependencies = [ "chrono", "clap", "eyre", + "hyper-util", "proto", "tokio", "tracing", diff --git a/ctl/Cargo.toml b/ctl/Cargo.toml index b969738..e336a9b 100644 --- a/ctl/Cargo.toml +++ b/ctl/Cargo.toml @@ -14,6 +14,7 @@ axum.workspace = true chrono.workspace = true clap.workspace = true eyre.workspace = true +hyper-util.workspace = true tokio.workspace = true tracing.workspace = true uuid.workspace = true diff --git a/ctl/src/balancer/mod.rs b/ctl/src/balancer/mod.rs new file mode 100644 index 0000000..8d9d68c --- /dev/null +++ b/ctl/src/balancer/mod.rs @@ -0,0 +1,135 @@ +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + str::FromStr as _, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; + +use axum::{ + body::Body, + extract::{ConnectInfo, Request, State}, + http::{ + uri::{Authority, Scheme}, + HeaderValue, StatusCode, Uri, + }, + response::IntoResponse, +}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use proto::{ + common::{instance::InstanceId, service::ServiceId}, + well_known::{PROXY_FORWARDED_HEADER_NAME, PROXY_INSTANCE_HEADER_NAME}, +}; +use utils::http::{self, OptionExt as _, ResultExt as _}; + +#[derive(Default)] +pub struct InstanceBag { + pub instances: Vec<(InstanceId, IpAddr)>, + pub count: AtomicUsize, +} + +#[derive(Clone)] +pub struct BalancerState { + pub addrs: Arc>>, + pub client: Client, +} + +impl BalancerState { + #[must_use] + pub fn new() -> (Self, BalancerHandle) { + let map = Arc::new(Mutex::new(HashMap::default())); + ( + BalancerState { + addrs: map.clone(), + client: { + let mut connector = HttpConnector::new(); + connector.set_keepalive(Some(Duration::from_secs(60))); + connector.set_nodelay(true); + Client::builder(TokioExecutor::new()).build::<_, Body>(connector) + }, + }, + BalancerHandle { addrs: map }, + ) + } + + pub fn next(&self, service: &ServiceId) -> (InstanceId, IpAddr) { + let map = self.addrs.lock().unwrap(); + let bag = map.get(service).unwrap(); + let count = bag.count.fetch_add(1, Ordering::Relaxed); + bag.instances[count % bag.instances.len()] + } +} + +pub struct BalancerHandle { + pub addrs: Arc>>, +} + +impl BalancerHandle { + #[allow(dead_code)] + pub fn add_instance(&mut self, id: ServiceId, at: (InstanceId, IpAddr)) { + let mut map = self.addrs.lock().unwrap(); + let bag = map.entry(id).or_default(); + bag.instances.push(at); + } + + #[allow(dead_code)] + pub fn drop_instance(&mut self, id: &ServiceId, at: (InstanceId, IpAddr)) { + let mut map = self.addrs.lock().unwrap(); + let Some(bag) = map.get_mut(id) else { + return; + }; + bag.instances + .retain(|(inst, addr)| inst == &at.0 && addr == &at.1); + } +} + +pub async fn proxy( + ConnectInfo(addr): ConnectInfo, + State(balancer): State, + mut req: Request, +) -> http::Result { + let service = extract_service_id(&mut req)?; + + let (instance, server_addr) = balancer.next(&service); + + *req.uri_mut() = { + let uri = req.uri(); + let mut parts = uri.clone().into_parts(); + parts.authority = Authority::from_str(&server_addr.to_string()).ok(); + parts.scheme = Some(Scheme::HTTP); + Uri::from_parts(parts).unwrap() + }; + + req.headers_mut().insert( + PROXY_INSTANCE_HEADER_NAME, + HeaderValue::from_str(&instance.to_string()).unwrap(), + ); + req.headers_mut().insert( + PROXY_FORWARDED_HEADER_NAME, + HeaderValue::from_str(&addr.ip().to_string()).unwrap(), + ); + + balancer + .client + .request(req) + .await + .http_error(StatusCode::BAD_GATEWAY, "bad gateway") +} + +fn extract_service_id(req: &mut Request) -> http::Result { + let inner = req + .headers() + .get("Host") + .unwrap() + .to_str() + .ok() + .and_then(|s| s.parse().ok()) + .or_http_error(StatusCode::BAD_REQUEST, "invalid service name")?; + Ok(ServiceId(inner)) +} diff --git a/ctl/src/main.rs b/ctl/src/main.rs index 1d1e003..ed7b0fb 100644 --- a/ctl/src/main.rs +++ b/ctl/src/main.rs @@ -3,15 +3,17 @@ use std::{ sync::Arc, }; +use axum::handler::Handler; use clap::Parser; use proto::well_known::{CTL_BALANCER_PORT, CTL_HTTP_PORT}; use tokio::task::JoinSet; use tracing::info; use utils::server::mk_listener; -use crate::{args::CtlArgs, http::HttpState, worker_mgr::WorkerMgr}; +use crate::{args::CtlArgs, balancer::BalancerState, http::HttpState, worker_mgr::WorkerMgr}; mod args; +mod balancer; mod http; mod worker_mgr; @@ -24,7 +26,7 @@ async fn main() -> eyre::Result<()> { let args = Arc::new(CtlArgs::parse()); info!(?args, "started ctl"); - let _balancer_listener = mk_listener(ANY_IP, CTL_BALANCER_PORT).await?; + let balancer_listener = mk_listener(ANY_IP, CTL_BALANCER_PORT).await?; let http_listener = mk_listener(ANY_IP, CTL_HTTP_PORT).await?; let mut bag = JoinSet::new(); @@ -34,6 +36,15 @@ async fn main() -> eyre::Result<()> { worker_mgr.run().await; }); + let (balancer, _balancer_handle) = BalancerState::new(); + bag.spawn(async move { + let app = balancer::proxy + .with_state(balancer) + .into_make_service_with_connect_info::(); + info!("balancer http listening at {ANY_IP}:{CTL_BALANCER_PORT}"); + axum::serve(balancer_listener, app).await.unwrap(); + }); + bag.spawn(async move { let state = HttpState { worker_mgr: worker_mgr_handle, diff --git a/proto/src/well_known.rs b/proto/src/well_known.rs index fca78c3..dffc89b 100644 --- a/proto/src/well_known.rs +++ b/proto/src/well_known.rs @@ -2,6 +2,7 @@ use std::time::Duration; pub const GRACEFUL_SHUTDOWN_DEADLINE: Duration = Duration::from_secs(20); +pub const PROXY_FORWARDED_HEADER_NAME: &str = "X-Tuc-Fwd-For"; pub const PROXY_INSTANCE_HEADER_NAME: &str = "X-Tuc-Inst"; pub const CTL_HTTP_PORT: u16 = 7070;