diff --git a/Cargo.lock b/Cargo.lock index e9a2a95c43..1cad16b7f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2601,9 +2601,9 @@ dependencies = [ [[package]] name = "rustls-acme" -version = "0.4.0-beta4" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ca5c1e3f26e9be750856a2d986219fac80b28ecf65f69b3d396c50b46b7c9b2" +checksum = "b5b29a57dc718b12d5eba48a08548c615854a6ce661d9463254779f358922950" dependencies = [ "async-h1", "async-io", diff --git a/Cargo.toml b/Cargo.toml index a6294c1f64..5994986fda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ rayon = "1.5.1" redb = { version = "0.5.0", git = "https://github.com/cberner/redb", branch = "master" } rust-embed = "6.4.0" rustls = "0.20.6" -rustls-acme = { version = "0.4.0-beta2", features = ["axum"] } +rustls-acme = { version = "0.5.0", features = ["axum"] } serde = { version = "1.0.137", features = ["derive"] } serde_cbor = "0.11.2" serde_json = "1.0.81" diff --git a/deploy/ord.service b/deploy/ord.service index a5888cd9f1..7b90ccab78 100644 --- a/deploy/ord.service +++ b/deploy/ord.service @@ -15,7 +15,8 @@ ExecStart=/usr/local/bin/ord \ --chain ${CHAIN} \ server \ --acme-contact mailto:casey@rodarmor.com \ - --https-port 443 + --http \ + --https Group=ord MemoryDenyWriteExecute=true NoNewPrivileges=true diff --git a/src/main.rs b/src/main.rs index ec93bfe1bc..3c049b8ad1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -68,7 +68,7 @@ use { thread, time::{Duration, Instant}, }, - tokio::runtime::Runtime, + tokio::{runtime::Runtime, task}, tower_http::cors::{Any, CorsLayer}, }; diff --git a/src/subcommand/server.rs b/src/subcommand/server.rs index 0f41cf6df3..2121ebfff5 100644 --- a/src/subcommand/server.rs +++ b/src/subcommand/server.rs @@ -9,7 +9,6 @@ use { }, }, axum::{body, http::header, response::Response}, - clap::ArgGroup, rust_embed::RustEmbed, rustls_acme::{ acme::{LETS_ENCRYPT_PRODUCTION_DIRECTORY, LETS_ENCRYPT_STAGING_DIRECTORY}, @@ -47,7 +46,6 @@ impl Display for StaticHtml { } #[derive(Debug, Parser)] -#[clap(group = ArgGroup::new("port").multiple(false))] pub(crate) struct Server { #[clap( long, @@ -62,20 +60,23 @@ pub(crate) struct Server { acme_domain: Vec<String>, #[clap( long, - group = "port", - help = "Listen on <HTTP_PORT> for incoming HTTP requests. Defaults to 80." + help = "Listen on <HTTP_PORT> for incoming HTTP requests. [default: 80]." )] http_port: Option<u16>, #[clap( long, group = "port", - help = "Listen on <HTTPS_PORT> for incoming HTTPS requests." + help = "Listen on <HTTPS_PORT> for incoming HTTPS requests. [default: 443]." )] https_port: Option<u16>, - #[structopt(long, help = "Store ACME TLS certificates in <ACME_CACHE>.")] + #[clap(long, help = "Store ACME TLS certificates in <ACME_CACHE>.")] acme_cache: Option<PathBuf>, - #[structopt(long, help = "Provide ACME contact <ACME_CONTACT>.")] + #[clap(long, help = "Provide ACME contact <ACME_CONTACT>.")] acme_contact: Vec<String>, + #[clap(long, help = "Serve HTTP traffic on <HTTP_PORT>.")] + http: bool, + #[clap(long, help = "Serve HTTPS traffic on <HTTPS_PORT>.")] + https: bool, } impl Server { @@ -91,7 +92,7 @@ impl Server { thread::sleep(Duration::from_millis(100)); }); - let app = Router::new() + let router = Router::new() .route("/", get(Self::home)) .route("/block/:hash", get(Self::block)) .route("/bounties", get(Self::bounties)) @@ -113,35 +114,58 @@ impl Server { .allow_origin(Any), ); - let port = self.port(); - - let addr = (self.address.as_str(), port) - .to_socket_addrs()? - .next() - .ok_or_else(|| anyhow!("Failed to get socket addrs"))?; - let handle = Handle::new(); LISTENERS.lock().unwrap().push(handle.clone()); - let server = axum_server::Server::bind(addr).handle(handle); - - match self.acceptor(&options)? { - Some(acceptor) => { - server - .acceptor(acceptor) - .serve(app.into_make_service()) - .await? - } - None => server.serve(app.into_make_service()).await?, - } + let (http_result, https_result) = tokio::join!( + self.spawn(&router, &handle, None)?, + self.spawn(&router, &handle, self.acceptor(&options)?)? + ); + http_result.and(https_result)?.transpose()?; Ok(()) }) } - fn port(&self) -> u16 { - self.http_port.or(self.https_port).unwrap_or(80) + fn spawn( + &self, + router: &Router, + handle: &Handle, + https_acceptor: Option<AxumAcceptor>, + ) -> Result<task::JoinHandle<Option<io::Result<()>>>> { + let addr = if https_acceptor.is_some() { + self.https_port() + } else { + self.http_port() + } + .map(|port| { + (self.address.as_str(), port) + .to_socket_addrs()? + .next() + .ok_or_else(|| anyhow!("Failed to get socket addrs")) + .map(|addr| (addr, router.clone(), handle.clone())) + }) + .transpose()?; + + Ok(tokio::spawn(async move { + if let Some((addr, router, handle)) = addr { + Some(if let Some(acceptor) = https_acceptor { + axum_server::Server::bind(addr) + .handle(handle) + .acceptor(acceptor) + .serve(router.into_make_service()) + .await + } else { + axum_server::Server::bind(addr) + .handle(handle) + .serve(router.into_make_service()) + .await + }) + } else { + None + } + })) } fn acme_cache(acme_cache: Option<&PathBuf>, options: &Options) -> Result<PathBuf> { @@ -160,8 +184,24 @@ impl Server { } } + fn http_port(&self) -> Option<u16> { + if self.http || self.http_port.is_some() || (self.https_port.is_none() && !self.https) { + Some(self.http_port.unwrap_or(80)) + } else { + None + } + } + + fn https_port(&self) -> Option<u16> { + if self.https || self.https_port.is_some() { + Some(self.https_port.unwrap_or(443)) + } else { + None + } + } + fn acceptor(&self, options: &Options) -> Result<Option<AxumAcceptor>> { - if self.https_port.is_some() { + if self.https_port().is_some() { let config = AcmeConfig::new(Self::acme_domains(&self.acme_domain)?) .contact(&self.acme_contact) .cache_option(Some(DirCache::new(Self::acme_cache( @@ -443,27 +483,89 @@ impl Server { mod tests { use super::*; - #[test] - fn port_defaults_to_80() { - match Arguments::try_parse_from(&["ord", "server"]) - .unwrap() - .subcommand - { - Subcommand::Server(server) => assert_eq!(server.port(), 80), - subcommand => panic!("Unexpected subcommand: {subcommand:?}"), + fn parse_server_args(args: &str) -> Server { + match Arguments::try_parse_from( + ["ord", "server"] + .iter() + .cloned() + .chain(args.split_whitespace()), + ) { + Ok(arguments) => match arguments.subcommand { + Subcommand::Server(server) => server, + subcommand => panic!("Unexpected subcommand: {subcommand:?}"), + }, + Err(err) => panic!("Error parsing arguments: {err}"), } } #[test] - fn http_and_https_port_conflict() { - let err = Arguments::try_parse_from(&["ord", "server", "--http-port=0", "--https-port=0"]) - .unwrap_err() - .to_string(); + fn http_and_https_port_dont_conflict() { + parse_server_args( + "--http-port 0 --https-port 0 --acme-cache foo --acme-contact bar --acme-domain baz", + ); + } + + #[test] + fn http_port_defaults_to_80() { + assert_eq!(parse_server_args("").http_port(), Some(80)); + } + + #[test] + fn https_port_defaults_to_none() { + assert_eq!(parse_server_args("").https_port(), None); + } + + #[test] + fn https_sets_https_port_to_443() { + assert_eq!( + parse_server_args("--https --acme-cache foo --acme-contact bar --acme-domain baz") + .https_port(), + Some(443) + ); + } + + #[test] + fn https_disables_http() { + assert_eq!( + parse_server_args("--https --acme-cache foo --acme-contact bar --acme-domain baz") + .http_port(), + None + ); + } + + #[test] + fn https_port_disables_http() { + assert_eq!( + parse_server_args("--https-port 433 --acme-cache foo --acme-contact bar --acme-domain baz") + .http_port(), + None + ); + } + + #[test] + fn https_port_sets_https_port() { + assert_eq!( + parse_server_args("--https-port 1000 --acme-cache foo --acme-contact bar --acme-domain baz") + .https_port(), + Some(1000) + ); + } + + #[test] + fn http_with_https_leaves_http_enabled() { + assert_eq!( + parse_server_args("--https --http --acme-cache foo --acme-contact bar --acme-domain baz") + .http_port(), + Some(80) + ); + } - assert!( - err.starts_with("error: The argument '--http-port <HTTP_PORT>' cannot be used with '--https-port <HTTPS_PORT>'\n"), - "{}", - err + #[test] + fn http_with_https_leaves_https_enabled() { + assert_eq!( + parse_server_args("--https --http --acme-cache foo --acme-contact bar --acme-domain baz") + .https_port(), + Some(443) ); }