diff --git a/Cargo.lock b/Cargo.lock index c4f0e80d5d..9102ef1bdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -689,6 +689,7 @@ dependencies = [ "linkerd-stack-tracing", "linkerd-system", "linkerd-tls", + "linkerd-tls-rustls", "linkerd-trace-context", "linkerd-tracing", "linkerd-transport-header", @@ -991,12 +992,6 @@ name = "linkerd-identity" version = "0.1.0" dependencies = [ "linkerd-dns-name", - "ring", - "thiserror", - "tokio-rustls", - "tracing", - "untrusted", - "webpki", ] [[package]] @@ -1009,7 +1004,6 @@ dependencies = [ "linkerd-errno", "pin-project", "tokio", - "tokio-rustls", "tokio-test", "tokio-util", ] @@ -1152,6 +1146,7 @@ dependencies = [ "linkerd-metrics", "linkerd-stack", "linkerd-tls", + "linkerd-tls-rustls", "linkerd2-proxy-api", "pin-project", "thiserror", @@ -1189,6 +1184,7 @@ dependencies = [ "linkerd-proxy-transport", "linkerd-stack", "linkerd-tls", + "linkerd-tls-rustls", "linkerd2-proxy-api", "parking_lot", "pin-project", @@ -1364,13 +1360,29 @@ dependencies = [ "linkerd-io", "linkerd-proxy-transport", "linkerd-stack", + "linkerd-tls-rustls", "linkerd-tracing", + "pin-project", "thiserror", "tokio", - "tokio-rustls", "tower", "tracing", "untrusted", +] + +[[package]] +name = "linkerd-tls-rustls" +version = "0.1.0" +dependencies = [ + "futures", + "linkerd-identity", + "linkerd-io", + "linkerd-stack", + "linkerd-tls", + "ring", + "thiserror", + "tokio-rustls", + "tracing", "webpki", ] diff --git a/Cargo.toml b/Cargo.toml index 1ce77e888b..3599140009 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ members = [ "linkerd/system", "linkerd/tonic-watch", "linkerd/tls", + "linkerd/tls/rustls", "linkerd/tracing", "linkerd/transport-header", "linkerd/transport-metrics", diff --git a/linkerd/app/admin/src/stack.rs b/linkerd/app/admin/src/stack.rs index d669ad38de..d97dc93cdb 100644 --- a/linkerd/app/admin/src/stack.rs +++ b/linkerd/app/admin/src/stack.rs @@ -2,8 +2,9 @@ use linkerd_app_core::{ classify, config::ServerConfig, detect, drain, errors, + identity::LocalCrtKey, metrics::{self, FmtMetrics}, - proxy::{http, identity::LocalCrtKey}, + proxy::http, serve, svc::{self, ExtractParam, InsertParam, Param}, tls, trace, diff --git a/linkerd/app/core/Cargo.toml b/linkerd/app/core/Cargo.toml index 5049399b43..e4f7c3735e 100644 --- a/linkerd/app/core/Cargo.toml +++ b/linkerd/app/core/Cargo.toml @@ -58,6 +58,7 @@ linkerd-tracing = { path = "../../tracing" } linkerd-transport-header = { path = "../../transport-header" } linkerd-transport-metrics = { path = "../../transport-metrics" } linkerd-tls = { path = "../../tls" } +linkerd-tls-rustls = { path = "../../tls/rustls" } linkerd-trace-context = { path = "../../trace-context" } regex = "1.5.4" serde_json = "1" diff --git a/linkerd/app/core/src/control.rs b/linkerd/app/core/src/control.rs index 4dbf6dd22b..7f6f31a3bf 100644 --- a/linkerd/app/core/src/control.rs +++ b/linkerd/app/core/src/control.rs @@ -54,7 +54,8 @@ impl Config { > + Clone, > where - L: Clone + svc::Param + Send + Sync + 'static, + L: svc::NewService, + L: Clone + Send + Sync + 'static, { let addr = self.addr; diff --git a/linkerd/app/core/src/lib.rs b/linkerd/app/core/src/lib.rs index 866e0cf255..0f8bdff984 100644 --- a/linkerd/app/core/src/lib.rs +++ b/linkerd/app/core/src/lib.rs @@ -20,13 +20,14 @@ pub use linkerd_dns; pub use linkerd_error::{is_error, Error, Infallible, Recover, Result}; pub use linkerd_exp_backoff as exp_backoff; pub use linkerd_http_metrics as http_metrics; -pub use linkerd_identity as identity; pub use linkerd_io as io; pub use linkerd_opencensus as opencensus; +pub use linkerd_proxy_identity as identity; pub use linkerd_service_profiles as profiles; pub use linkerd_stack_metrics as stack_metrics; pub use linkerd_stack_tracing as stack_tracing; pub use linkerd_tls as tls; +pub use linkerd_tls_rustls as rustls; pub use linkerd_tracing as trace; pub use linkerd_transport_header as transport_header; @@ -56,7 +57,7 @@ const DEFAULT_PORT: u16 = 80; #[derive(Clone, Debug)] pub struct ProxyRuntime { - pub identity: proxy::identity::LocalCrtKey, + pub identity: identity::LocalCrtKey, pub metrics: metrics::Proxy, pub tap: proxy::tap::Registry, pub span_sink: http_tracing::OpenCensusSink, diff --git a/linkerd/app/core/src/proxy/mod.rs b/linkerd/app/core/src/proxy/mod.rs index ec37b9965d..2fda2f6f4b 100644 --- a/linkerd/app/core/src/proxy/mod.rs +++ b/linkerd/app/core/src/proxy/mod.rs @@ -5,7 +5,6 @@ pub use linkerd_proxy_core as core; pub use linkerd_proxy_discover as discover; pub use linkerd_proxy_dns_resolve as dns_resolve; pub use linkerd_proxy_http as http; -pub use linkerd_proxy_identity as identity; pub use linkerd_proxy_resolve as resolve; pub use linkerd_proxy_tap as tap; pub use linkerd_proxy_tcp as tcp; diff --git a/linkerd/app/inbound/fuzz/Cargo.toml b/linkerd/app/inbound/fuzz/Cargo.toml index 3f7136287b..170afb67f0 100644 --- a/linkerd/app/inbound/fuzz/Cargo.toml +++ b/linkerd/app/inbound/fuzz/Cargo.toml @@ -11,15 +11,15 @@ cargo-fuzz = true [target.'cfg(fuzzing)'.dependencies] arbitrary = { version = "1", features = ["derive"] } -libfuzzer-sys = { version = "0.4.2", features = ["arbitrary-derive"] } -tokio = { version = "1", features = ["full"] } hyper = { version = "0.14.9", features = ["http1", "http2"] } http = "0.2" +libfuzzer-sys = { version = "0.4.2", features = ["arbitrary-derive"] } linkerd-app-core = { path = "../../core" } linkerd-app-inbound = { path = ".." } linkerd-app-test = { path = "../../test" } linkerd-proxy-identity = { path = "../../../proxy/identity", features = ["test-util"] } linkerd-tracing = { path = "../../../tracing", features = ["ansi"] } +tokio = { version = "1", features = ["full"] } tracing = "0.1" # Prevent this from interfering with workspaces diff --git a/linkerd/app/inbound/src/detect.rs b/linkerd/app/inbound/src/detect.rs index 43c551af7f..a6863ccbc6 100644 --- a/linkerd/app/inbound/src/detect.rs +++ b/linkerd/app/inbound/src/detect.rs @@ -4,8 +4,8 @@ use crate::{ }; use linkerd_app_core::{ detect, identity, io, - proxy::{http, identity::LocalCrtKey}, - svc, tls, + proxy::http, + rustls, svc, tls, transport::{ self, addrs::{ClientAddr, OrigDstAddr, Remote}, @@ -50,9 +50,11 @@ struct ConfigureHttpDetect; #[derive(Clone)] struct TlsParams { timeout: tls::server::Timeout, - identity: LocalCrtKey, + identity: identity::LocalCrtKey, } +type TlsIo = tls::server::Io>, I>; + // === impl Inbound === impl Inbound { @@ -92,7 +94,7 @@ impl Inbound { I: Debug + Send + Sync + Unpin + 'static, N: svc::NewService, N: Clone + Send + Sync + Unpin + 'static, - NSvc: svc::Service, Response = ()>, + NSvc: svc::Service, Response = ()>, NSvc: Send + Unpin + 'static, NSvc::Error: Into, NSvc::Future: Send, @@ -135,10 +137,12 @@ impl Inbound { .push_on_service(svc::MapTargetLayer::new(io::BoxedIo::new)) .into_inner(), ) - .push(tls::NewDetectTls::::layer(TlsParams { - timeout: tls::server::Timeout(detect_timeout), - identity: rt.identity.clone(), - })) + .push(tls::NewDetectTls::::layer( + TlsParams { + timeout: tls::server::Timeout(detect_timeout), + identity: rt.identity.clone(), + }, + )) .push_switch( // If this port's policy indicates that authentication is not required and // detection should be skipped, use the TCP stack directly. @@ -425,9 +429,9 @@ impl svc::ExtractParam for TlsParams { } } -impl svc::ExtractParam for TlsParams { +impl svc::ExtractParam for TlsParams { #[inline] - fn extract_param(&self, _: &T) -> LocalCrtKey { + fn extract_param(&self, _: &T) -> identity::LocalCrtKey { self.identity.clone() } } diff --git a/linkerd/app/inbound/src/direct.rs b/linkerd/app/inbound/src/direct.rs index 0a88b31449..0bacf47db5 100644 --- a/linkerd/app/inbound/src/direct.rs +++ b/linkerd/app/inbound/src/direct.rs @@ -1,14 +1,14 @@ use crate::{policy, Inbound}; use linkerd_app_core::{ - io, - proxy::identity::LocalCrtKey, + identity::LocalCrtKey, + io, rustls, svc::{self, ExtractParam, InsertParam, Param}, tls, transport::{self, metrics::SensorIo, ClientAddr, OrigDstAddr, Remote, ServerAddr}, transport_header::{self, NewTransportHeaderServer, SessionProtocol, TransportHeader}, Conditional, Error, NameAddr, Result, }; -use std::{convert::TryFrom, fmt::Debug}; +use std::{convert::TryFrom, fmt::Debug, task}; use thiserror::Error; use tracing::{debug_span, info_span}; @@ -52,8 +52,9 @@ pub struct ClientInfo { pub local_addr: OrigDstAddr, } -type FwdIo = SensorIo>>; -pub type GatewayIo = io::EitherIo, SensorIo>>; +type TlsIo = tls::server::Io>, I>; +type FwdIo = SensorIo>>; +pub type GatewayIo = io::EitherIo, SensorIo>>; #[derive(Clone)] struct TlsParams { @@ -102,7 +103,6 @@ impl Inbound { rt.metrics.proxy.transport.clone(), )) .instrument(|_: &_| debug_span!("opaque")) - .check_new_service::() // When the transport header is present, it may be used for either local TCP // forwarding, or we may be processing an HTTP gateway connection. HTTP gateway // connections that have a transport header must provide a target name as a part of @@ -129,8 +129,13 @@ impl Inbound { negotiated_protocol: client.alpn, }, ); - let permit = allow.check_authorized(client.client_addr, &tls)?; - Ok(svc::Either::A(Local { addr: Remote(ServerAddr(addr)), permit, client_id: client.client_id, })) + let permit = + allow.check_authorized(client.client_addr, &tls)?; + Ok(svc::Either::A(Local { + addr: Remote(ServerAddr(addr)), + permit, + client_id: client.client_id, + })) } TransportHeader { port, @@ -167,30 +172,27 @@ impl Inbound { .instrument( |g: &GatewayTransportHeader| info_span!("gateway", dst = %g.target), ) - .check_new_service::>>() .into_inner(), ) // Use ALPN to determine whether a transport header should be read. .push(NewTransportHeaderServer::layer(detect_timeout)) - .push_request_filter( - |client: ClientInfo| -> Result<_> { - if client.header_negotiated() { - Ok(client) - } else { - Err(RefusedNoTarget.into()) - } - }, - ) - .check_new_service::>() + .push_request_filter(|client: ClientInfo| -> Result<_> { + if client.header_negotiated() { + Ok(client) + } else { + Err(RefusedNoTarget.into()) + } + }) // Build a ClientInfo target for each accepted connection. Refuse the // connection if it doesn't include an mTLS identity. .push_request_filter(ClientInfo::try_from) .push(svc::ArcNewService::layer()) - .push(tls::NewDetectTls::::layer(TlsParams { - timeout: tls::server::Timeout(detect_timeout), - identity: WithTransportHeaderAlpn(rt.identity.clone()), - })) - .check_new_service::() + .push(tls::NewDetectTls::::layer( + TlsParams { + timeout: tls::server::Timeout(detect_timeout), + identity: WithTransportHeaderAlpn(rt.identity.clone()), + }, + )) .push_on_service(svc::BoxService::layer()) .push(svc::ArcNewService::layer()) }) @@ -293,8 +295,20 @@ impl Param for GatewayTransportHeader { // === impl WithTransportHeaderAlpn === -impl svc::Param for WithTransportHeaderAlpn { - fn param(&self) -> tls::server::Config { +impl svc::Service for WithTransportHeaderAlpn +where + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, +{ + type Response = (tls::ServerTls, rustls::ServerIo); + type Error = io::Error; + type Future = rustls::TerminateFuture; + + #[inline] + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> task::Poll> { + task::Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { // Copy the underlying TLS config and set an ALPN value. // // TODO: Avoid cloning the server config for every connection. It would @@ -304,7 +318,7 @@ impl svc::Param for WithTransportHeaderAlpn { config .alpn_protocols .push(transport_header::PROTOCOL.into()); - config.into() + rustls::terminate(config.into(), io) } } diff --git a/linkerd/app/inbound/src/lib.rs b/linkerd/app/inbound/src/lib.rs index b23ea3bb4b..0318430aef 100644 --- a/linkerd/app/inbound/src/lib.rs +++ b/linkerd/app/inbound/src/lib.rs @@ -21,9 +21,9 @@ use linkerd_app_core::{ config::{ConnectConfig, ProxyConfig}, drain, http_tracing::OpenCensusSink, + identity::LocalCrtKey, io, - proxy::tcp, - proxy::{identity::LocalCrtKey, tap}, + proxy::{tap, tcp}, svc, transport::{self, Remote, ServerAddr}, Error, NameMatch, ProxyRuntime, diff --git a/linkerd/app/inbound/src/policy/config.rs b/linkerd/app/inbound/src/policy/config.rs index a9168ac913..621b14aec1 100644 --- a/linkerd/app/inbound/src/policy/config.rs +++ b/linkerd/app/inbound/src/policy/config.rs @@ -1,5 +1,5 @@ use super::{discover::Discover, DefaultPolicy, ServerPolicy, Store}; -use linkerd_app_core::{control, dns, metrics, proxy::identity::LocalCrtKey, svc::NewService}; +use linkerd_app_core::{control, dns, identity::LocalCrtKey, metrics, svc::NewService}; use std::collections::{HashMap, HashSet}; /// Configures inbound policies. diff --git a/linkerd/app/inbound/src/test_util.rs b/linkerd/app/inbound/src/test_util.rs index 54b9603974..7935983c53 100644 --- a/linkerd/app/inbound/src/test_util.rs +++ b/linkerd/app/inbound/src/test_util.rs @@ -3,10 +3,11 @@ pub use futures::prelude::*; use linkerd_app_core::{ config, dns::Suffix, - drain, exp_backoff, metrics, + drain, exp_backoff, + identity::LocalCrtKey, + metrics, proxy::{ http::{h1, h2}, - identity::LocalCrtKey, tap, }, transport::{Keepalive, ListenAddr}, diff --git a/linkerd/app/outbound/src/lib.rs b/linkerd/app/outbound/src/lib.rs index efd053b9e9..6dc1a26534 100644 --- a/linkerd/app/outbound/src/lib.rs +++ b/linkerd/app/outbound/src/lib.rs @@ -23,11 +23,11 @@ use linkerd_app_core::{ config::ProxyConfig, drain, http_tracing::OpenCensusSink, + identity::LocalCrtKey, io, profiles, proxy::{ api_resolve::{ConcreteAddr, Metadata}, core::Resolve, - identity::LocalCrtKey, tap, }, serve, diff --git a/linkerd/app/src/dst.rs b/linkerd/app/src/dst.rs index 17aad2aec5..2f0b3693f6 100644 --- a/linkerd/app/src/dst.rs +++ b/linkerd/app/src/dst.rs @@ -1,9 +1,10 @@ use linkerd_app_core::{ control, dns, exp_backoff::{ExponentialBackoff, ExponentialBackoffStream}, + identity::LocalCrtKey, metrics, profiles::{self, DiscoveryRejected}, - proxy::{api_resolve as api, http, identity::LocalCrtKey, resolve::recover}, + proxy::{api_resolve as api, http, resolve::recover}, svc::{self, NewService}, Error, Recover, }; diff --git a/linkerd/app/src/identity.rs b/linkerd/app/src/identity.rs index 49c89fc211..83e397d601 100644 --- a/linkerd/app/src/identity.rs +++ b/linkerd/app/src/identity.rs @@ -1,7 +1,4 @@ -pub use linkerd_app_core::identity::{ - Crt, CrtKey, Csr, InvalidName, Key, Name, TokenSource, TrustAnchors, -}; -pub use linkerd_app_core::proxy::identity::{certify, metrics, LocalCrtKey}; +pub use linkerd_app_core::identity::*; use linkerd_app_core::{ control, dns, exp_backoff::{ExponentialBackoff, ExponentialBackoffStream}, diff --git a/linkerd/app/src/lib.rs b/linkerd/app/src/lib.rs index 2cb84ebd11..db1d190d1a 100644 --- a/linkerd/app/src/lib.rs +++ b/linkerd/app/src/lib.rs @@ -352,7 +352,7 @@ impl App { .await_crt() .map_ok(move |id| { latch.release(); - info!("Certified identity: {}", id.name().as_ref()); + info!("Certified identity: {}", id.name()); }) .map_err(|_| { // The daemon task was lost?! diff --git a/linkerd/app/src/tap.rs b/linkerd/app/src/tap.rs index a173b83df7..fce6ff24cb 100644 --- a/linkerd/app/src/tap.rs +++ b/linkerd/app/src/tap.rs @@ -2,7 +2,7 @@ use futures::prelude::*; use linkerd_app_core::{ config::ServerConfig, drain, - proxy::identity::LocalCrtKey, + identity::LocalCrtKey, proxy::tap, serve, svc::{self, ExtractParam, InsertParam, Param}, diff --git a/linkerd/identity/Cargo.toml b/linkerd/identity/Cargo.toml index b3e92e5bf3..5ae5915822 100644 --- a/linkerd/identity/Cargo.toml +++ b/linkerd/identity/Cargo.toml @@ -4,17 +4,7 @@ version = "0.1.0" authors = ["Linkerd Developers "] license = "Apache-2.0" edition = "2018" - -[features] -default = [] -test-util = [] +publish = false [dependencies] linkerd-dns-name = { path = "../dns/name" } -ring = "0.16.19" -thiserror = "1.0" -tokio-rustls = "0.22" -tracing = "0.1.29" -untrusted = "0.7" -webpki = "=0.21.4" - diff --git a/linkerd/identity/src/lib.rs b/linkerd/identity/src/lib.rs index 615d129b67..2272e91237 100644 --- a/linkerd/identity/src/lib.rs +++ b/linkerd/identity/src/lib.rs @@ -1,131 +1,17 @@ #![deny(warnings, rust_2018_idioms)] #![forbid(unsafe_code)] -pub use ring::error::KeyRejected; -use ring::rand; -use ring::signature::EcdsaKeyPair; -use std::{fmt, fs, io, ops::Deref, str::FromStr, sync::Arc, time::SystemTime}; -use thiserror::Error; -use tokio_rustls::rustls; -use tracing::{debug, warn}; - -#[cfg(any(test, feature = "test-util"))] -pub mod test_util; - pub use linkerd_dns_name::InvalidName; - -/// A DER-encoded X.509 certificate signing request. -#[derive(Clone, Debug)] -pub struct Csr(Arc>); +use std::{fmt, ops::Deref, str::FromStr, sync::Arc}; /// An endpoint's identity. #[derive(Clone, Eq, PartialEq, Hash)] pub struct Name(Arc); -#[derive(Clone, Debug)] -pub struct Key(Arc); - -struct SigningKey(Arc); -struct Signer(Arc); - -#[derive(Clone)] -pub struct TrustAnchors(Arc); - -#[derive(Clone, Debug)] -pub struct TokenSource(Arc); - -#[derive(Clone, Debug)] -pub struct Crt { - id: LocalId, - expiry: SystemTime, - chain: Vec, -} - -#[derive(Clone)] -pub struct CrtKey { - id: LocalId, - expiry: SystemTime, - client_config: Arc, - server_config: Arc, -} - -struct CertResolver(rustls::sign::CertifiedKey); - -#[derive(Clone, Debug, Error)] -#[error(transparent)] -pub struct InvalidCrt(rustls::TLSError); - /// A newtype for local server identities. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct LocalId(pub Name); -// These must be kept in sync: -static SIGNATURE_ALG_RING_SIGNING: &ring::signature::EcdsaSigningAlgorithm = - &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING; -const SIGNATURE_ALG_RUSTLS_SCHEME: rustls::SignatureScheme = - rustls::SignatureScheme::ECDSA_NISTP256_SHA256; -const SIGNATURE_ALG_RUSTLS_ALGORITHM: rustls::internal::msgs::enums::SignatureAlgorithm = - rustls::internal::msgs::enums::SignatureAlgorithm::ECDSA; -const TLS_VERSIONS: &[rustls::ProtocolVersion] = &[rustls::ProtocolVersion::TLSv1_3]; - -// === impl Csr === - -impl Csr { - pub fn from_der(der: Vec) -> Option { - if der.is_empty() { - return None; - } - - Some(Csr(Arc::new(der))) - } - - pub fn to_vec(&self) -> Vec { - self.0.to_vec() - } -} - -// === impl Key === - -impl Key { - pub fn from_pkcs8(b: &[u8]) -> Result { - let k = EcdsaKeyPair::from_pkcs8(SIGNATURE_ALG_RING_SIGNING, b)?; - Ok(Key(Arc::new(k))) - } -} - -impl rustls::sign::SigningKey for SigningKey { - fn choose_scheme( - &self, - offered: &[rustls::SignatureScheme], - ) -> Option> { - if offered.contains(&SIGNATURE_ALG_RUSTLS_SCHEME) { - Some(Box::new(Signer(self.0.clone()))) - } else { - None - } - } - - fn algorithm(&self) -> rustls::internal::msgs::enums::SignatureAlgorithm { - SIGNATURE_ALG_RUSTLS_ALGORITHM - } -} - -impl rustls::sign::Signer for Signer { - fn sign(&self, message: &[u8]) -> Result, rustls::TLSError> { - let rng = rand::SystemRandom::new(); - self.0 - .sign(&rng, message) - .map(|signature| signature.as_ref().to_owned()) - .map_err(|ring::error::Unspecified| { - rustls::TLSError::General("Signing Failed".to_owned()) - }) - } - - fn get_scheme(&self) -> rustls::SignatureScheme { - SIGNATURE_ALG_RUSTLS_SCHEME - } -} - // === impl Name === impl From for Name { @@ -166,243 +52,9 @@ impl fmt::Display for Name { } } -// === impl TokenSource === - -impl TokenSource { - pub fn if_nonempty_file(p: String) -> io::Result { - let ts = TokenSource(Arc::new(p)); - ts.load().map(|_| ts) - } - - pub fn load(&self) -> io::Result> { - let t = fs::read(self.0.as_str())?; - - if t.is_empty() { - return Err(io::Error::new(io::ErrorKind::Other, "token is empty")); - } - - Ok(t) - } -} - -// === impl TrustAnchors === - -impl TrustAnchors { - #[cfg(any(test, feature = "test-util"))] - fn empty() -> Self { - TrustAnchors(Arc::new(rustls::ClientConfig::new())) - } - - pub fn from_pem(s: &str) -> Option { - use std::io::Cursor; - - let mut roots = rustls::RootCertStore::empty(); - let (added, skipped) = roots.add_pem_file(&mut Cursor::new(s)).ok()?; - if skipped != 0 { - warn!("skipped {} trust anchors in trust anchors file", skipped); - } - if added == 0 { - return None; - } - - let mut c = rustls::ClientConfig::new(); - - // XXX: Rustls's built-in verifiers don't let us tweak things as fully - // as we'd like (e.g. controlling the set of trusted signature - // algorithms), but they provide good enough defaults for now. - // TODO: lock down the verification further. - // TODO: Change Rustls's API to Avoid needing to clone `root_cert_store`. - c.root_store = roots; - - // Disable session resumption for the time-being until resumption is - // more tested. - c.enable_tickets = false; - - Some(TrustAnchors(Arc::new(c))) - } - - pub fn certify(&self, key: Key, crt: Crt) -> Result { - let mut client = self.0.as_ref().clone(); - - let crt_id = webpki::DNSNameRef::try_from_ascii(crt.id.as_bytes()) - .expect("certificate ID must be a valid DNS name"); - - // Ensure the certificate is valid for the services we terminate for - // TLS. This assumes that server cert validation does the same or - // more validation than client cert validation. - // - // XXX: Rustls currently only provides access to a - // `ServerCertVerifier` through - // `rustls::ClientConfig::get_verifier()`. - // - // XXX: Once `rustls::ServerCertVerified` is exposed in Rustls's - // safe API, use it to pass proof to CertCertResolver::new.... - // - // TODO: Restrict accepted signature algorithms. - static NO_OCSP: &[u8] = &[]; - client - .get_verifier() - .verify_server_cert(&client.root_store, &crt.chain, crt_id, NO_OCSP) - .map_err(InvalidCrt)?; - debug!("certified {}", crt.id); - - let k = SigningKey(key.0); - let key = rustls::sign::CertifiedKey::new(crt.chain, Arc::new(Box::new(k))); - let resolver = Arc::new(CertResolver(key)); - - // Enable client authentication. - client.client_auth_cert_resolver = resolver.clone(); - - // Ask TLS clients for a certificate and accept any certificate issued - // by our trusted CA(s). - // - // XXX: Rustls's built-in verifiers don't let us tweak things as fully - // as we'd like (e.g. controlling the set of trusted signature - // algorithms), but they provide good enough defaults for now. - // TODO: lock down the verification further. - // - // TODO: Change Rustls's API to Avoid needing to clone `root_cert_store`. - let mut server = rustls::ServerConfig::new( - rustls::AllowAnyAnonymousOrAuthenticatedClient::new(self.0.root_store.clone()), - ); - server.versions = TLS_VERSIONS.to_vec(); - server.cert_resolver = resolver; - - Ok(CrtKey { - id: crt.id, - expiry: crt.expiry, - client_config: Arc::new(client), - server_config: Arc::new(server), - }) - } - - pub fn client_config(&self) -> Arc { - self.0.clone() - } -} - -impl fmt::Debug for TrustAnchors { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TrustAnchors").finish() - } -} - -// === Crt === - -impl Crt { - pub fn new( - id: LocalId, - leaf: Vec, - intermediates: Vec>, - expiry: SystemTime, - ) -> Self { - let mut chain = Vec::with_capacity(intermediates.len() + 1); - chain.push(rustls::Certificate(leaf)); - chain.extend(intermediates.into_iter().map(rustls::Certificate)); - - Self { id, chain, expiry } - } - - pub fn name(&self) -> &Name { - &*self.id - } -} - -impl From<&'_ Crt> for LocalId { - fn from(crt: &Crt) -> LocalId { - crt.id.clone() - } -} - -// === CrtKey === - -impl CrtKey { - pub fn name(&self) -> &Name { - &*self.id - } - - pub fn expiry(&self) -> SystemTime { - self.expiry - } - - pub fn id(&self) -> &LocalId { - &self.id - } - - pub fn client_config(&self) -> Arc { - self.client_config.clone() - } - - pub fn server_config(&self) -> Arc { - self.server_config.clone() - } -} - -impl fmt::Debug for CrtKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("CrtKey") - .field("id", &self.id) - .field("expiry", &self.expiry) - .finish() - } -} - -// === impl CertResolver === - -impl rustls::ResolvesClientCert for CertResolver { - fn resolve( - &self, - _acceptable_issuers: &[&[u8]], - sigschemes: &[rustls::SignatureScheme], - ) -> Option { - // The proxy's server-side doesn't send the list of acceptable issuers so - // don't bother looking at `_acceptable_issuers`. - self.resolve_(sigschemes) - } - - fn has_certs(&self) -> bool { - true - } -} - -impl CertResolver { - fn resolve_( - &self, - sigschemes: &[rustls::SignatureScheme], - ) -> Option { - if !sigschemes.contains(&SIGNATURE_ALG_RUSTLS_SCHEME) { - debug!("signature scheme not supported -> no certificate"); - return None; - } - Some(self.0.clone()) - } -} - -impl rustls::ResolvesServerCert for CertResolver { - fn resolve(&self, hello: rustls::ClientHello<'_>) -> Option { - let server_name = if let Some(server_name) = hello.server_name() { - server_name - } else { - debug!("no SNI -> no certificate"); - return None; - }; - - // Verify that our certificate is valid for the given SNI name. - let c = (&self.0.cert) - .first() - .map(rustls::Certificate::as_ref) - .unwrap_or(&[]); // An empty input will fail to parse. - if let Err(err) = - webpki::EndEntityCert::from(c).and_then(|c| c.verify_is_valid_for_dns_name(server_name)) - { - debug!( - "our certificate is not valid for the SNI name -> no certificate: {:?}", - err - ); - return None; - } - - self.resolve_(hello.sigschemes()) +impl From for Name { + fn from(LocalId(name): LocalId) -> Name { + name } } @@ -414,12 +66,6 @@ impl From for LocalId { } } -impl From for Name { - fn from(LocalId(name): LocalId) -> Name { - name - } -} - impl Deref for LocalId { type Target = Name; @@ -433,42 +79,3 @@ impl fmt::Display for LocalId { self.0.fmt(f) } } - -#[cfg(test)] -mod tests { - use super::test_util::*; - - #[test] - fn can_construct_client_and_server_config_from_valid_settings() { - FOO_NS1.validate().expect("foo.ns1 must be valid"); - } - - #[test] - fn recognize_ca_did_not_issue_cert() { - let s = Identity { - trust_anchors: include_bytes!("testdata/ca2.pem"), - ..FOO_NS1 - }; - assert!(s.validate().is_err(), "ca2 should not validate foo.ns1"); - } - - #[test] - fn recognize_cert_is_not_valid_for_identity() { - let s = Identity { - crt: BAR_NS1.crt, - key: BAR_NS1.key, - ..FOO_NS1 - }; - assert!(s.validate().is_err(), "identity should not be valid"); - } - - #[test] - #[ignore] // XXX this doesn't fail because we don't actually check the key against the cert... - fn recognize_private_key_is_not_valid_for_cert() { - let s = Identity { - key: BAR_NS1.key, - ..FOO_NS1 - }; - assert!(s.validate().is_err(), "identity should not be valid"); - } -} diff --git a/linkerd/io/Cargo.toml b/linkerd/io/Cargo.toml index 3005a8e8a4..4e9ab36fa1 100644 --- a/linkerd/io/Cargo.toml +++ b/linkerd/io/Cargo.toml @@ -18,7 +18,6 @@ futures = { version = "0.3", default-features = false } bytes = "1" linkerd-errno = { path = "../errno" } tokio = { version = "1", features = ["io-util", "net"] } -tokio-rustls = "0.22" tokio-test = { version = "0.4", optional = true } tokio-util = { version = "0.6", features = ["io"] } pin-project = "1" diff --git a/linkerd/io/src/lib.rs b/linkerd/io/src/lib.rs index 8993c81217..6653e63c3a 100644 --- a/linkerd/io/src/lib.rs +++ b/linkerd/io/src/lib.rs @@ -64,18 +64,6 @@ impl PeerAddr for tokio::net::TcpStream { } } -impl PeerAddr for tokio_rustls::client::TlsStream { - fn peer_addr(&self) -> Result { - self.get_ref().0.peer_addr() - } -} - -impl PeerAddr for tokio_rustls::server::TlsStream { - fn peer_addr(&self) -> Result { - self.get_ref().0.peer_addr() - } -} - #[cfg(feature = "tokio-test")] impl PeerAddr for tokio_test::io::Mock { fn peer_addr(&self) -> Result { diff --git a/linkerd/proxy/identity/Cargo.toml b/linkerd/proxy/identity/Cargo.toml index 373d7d3f64..92770b88bc 100644 --- a/linkerd/proxy/identity/Cargo.toml +++ b/linkerd/proxy/identity/Cargo.toml @@ -8,7 +8,7 @@ publish = false [features] rustfmt = ["linkerd2-proxy-api/rustfmt"] -test-util = ["linkerd-identity/test-util"] +test-util = ["linkerd-tls-rustls/test-util"] [dependencies] futures = { version = "0.3", default-features = false } @@ -18,6 +18,7 @@ linkerd-identity = { path = "../../identity" } linkerd-metrics = { path = "../../metrics" } linkerd-stack = { path = "../../stack" } linkerd-tls = { path = "../../tls" } +linkerd-tls-rustls = { path = "../../tls/rustls" } thiserror = "1" tokio = { version = "1", features = ["time", "sync"] } tonic = { version = "0.5", default-features = false } diff --git a/linkerd/proxy/identity/src/certify.rs b/linkerd/proxy/identity/src/certify.rs index 1925686b37..fea85bd0bc 100644 --- a/linkerd/proxy/identity/src/certify.rs +++ b/linkerd/proxy/identity/src/certify.rs @@ -1,41 +1,51 @@ +use crate::TokenSource; use http_body::Body; use linkerd2_proxy_api::identity::{self as api, identity_client::IdentityClient}; use linkerd_error::Error; use linkerd_identity as id; use linkerd_metrics::Counter; -use linkerd_stack::{NewService, Param}; +use linkerd_stack::{NewService, Param, Service}; use linkerd_tls as tls; -use pin_project::pin_project; -use std::convert::TryFrom; -use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use linkerd_tls_rustls::{self as rustls, Crt, CrtKey, Key, TrustAnchors}; +use std::{ + convert::TryFrom, + sync::Arc, + task, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use thiserror::Error; -use tokio::sync::watch; -use tokio::time::{self, Sleep}; +use tokio::{ + io, + sync::watch, + time::{self, Sleep}, +}; use tonic::{self as grpc, body::BoxBody, client::GrpcService}; use tracing::{debug, error, trace}; /// Configures the Identity service and local identity. #[derive(Clone, Debug)] pub struct Config { - pub trust_anchors: id::TrustAnchors, - pub key: id::Key, - pub csr: id::Csr, - pub token: id::TokenSource, + pub trust_anchors: TrustAnchors, + pub key: Key, + pub csr: Csr, + pub token: TokenSource, pub local_id: id::LocalId, pub min_refresh: Duration, pub max_refresh: Duration, } +/// A DER-encoded X.509 certificate signing request. +#[derive(Clone, Debug)] +pub struct Csr(Arc>); + /// Holds the process's local TLS identity state. /// /// Updates dynamically as certificates are provisioned from the Identity service. -#[pin_project] #[derive(Clone, Debug)] pub struct LocalCrtKey { - trust_anchors: id::TrustAnchors, + trust_anchors: TrustAnchors, id: id::LocalId, - crt_key: watch::Receiver>, + crt_key: watch::Receiver>, refreshes: Arc, } @@ -47,7 +57,7 @@ pub struct AwaitCrt(Option); #[error("identity initialization failed")] pub struct LostDaemon(()); -pub type CrtKeySender = watch::Sender>; +pub type CrtKeySender = watch::Sender>; #[derive(Debug)] pub struct Daemon { @@ -130,7 +140,7 @@ impl Daemon { ), Some(expiry) => { let key = config.key.clone(); - let crt = id::Crt::new( + let crt = Crt::new( config.local_id.clone(), leaf_certificate, intermediate_certificates, @@ -186,7 +196,7 @@ impl LocalCrtKey { } #[cfg(feature = "test-util")] - pub fn for_test(id: &id::test_util::Identity) -> Self { + pub fn for_test(id: &rustls::test_util::Identity) -> Self { let crt_key = id.validate().expect("Identity must be valid"); let (tx, rx) = watch::channel(Some(crt_key)); // Prevent the receiver stream from ending. @@ -203,7 +213,7 @@ impl LocalCrtKey { #[cfg(feature = "test-util")] pub fn default_for_test() -> Self { - Self::for_test(&id::test_util::DEFAULT_DEFAULT) + Self::for_test(&rustls::test_util::DEFAULT_DEFAULT) } pub async fn await_crt(mut self) -> Result { @@ -228,7 +238,7 @@ impl LocalCrtKey { &*self.id } - pub fn client_config(&self) -> tls::client::Config { + fn client_config(&self) -> Arc { if let Some(ref c) = *self.crt_key.borrow() { return c.client_config(); } @@ -236,24 +246,43 @@ impl LocalCrtKey { self.trust_anchors.client_config() } - pub fn server_config(&self) -> tls::server::Config { + pub fn server_config(&self) -> Arc { if let Some(ref c) = *self.crt_key.borrow() { return c.server_config(); } - tls::server::empty_config() + let verifier = rustls::NoClientAuth::new(); + Arc::new(rustls::ServerConfig::new(verifier)) } } -impl Param for LocalCrtKey { - fn param(&self) -> tls::client::Config { - self.client_config() +impl NewService for LocalCrtKey { + type Service = rustls::Connect; + + /// Creates a new TLS client service. + #[inline] + fn new_service(&self, target: tls::ClientTls) -> Self::Service { + rustls::Connect::new(target, self.client_config()) } } -impl Param for LocalCrtKey { - fn param(&self) -> tls::server::Config { - self.server_config() +impl Service for LocalCrtKey +where + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, +{ + type Response = (tls::ServerTls, rustls::ServerIo); + type Error = io::Error; + type Future = rustls::TerminateFuture; + + #[inline] + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> task::Poll> { + task::Poll::Ready(Ok(())) + } + + /// Terminates a server-side TLS connection. + #[inline] + fn call(&mut self, io: I) -> Self::Future { + rustls::terminate(self.server_config(), io) } } @@ -262,3 +291,19 @@ impl Param for LocalCrtKey { self.id().clone() } } + +// === impl Csr === + +impl Csr { + pub fn from_der(der: Vec) -> Option { + if der.is_empty() { + return None; + } + + Some(Csr(Arc::new(der))) + } + + pub fn to_vec(&self) -> Vec { + self.0.to_vec() + } +} diff --git a/linkerd/proxy/identity/src/lib.rs b/linkerd/proxy/identity/src/lib.rs index 4cab0f2b13..8494c3b0d8 100644 --- a/linkerd/proxy/identity/src/lib.rs +++ b/linkerd/proxy/identity/src/lib.rs @@ -3,5 +3,11 @@ pub mod certify; pub mod metrics; +mod token; -pub use self::certify::{AwaitCrt, CrtKeySender, LocalCrtKey}; +pub use self::{ + certify::{AwaitCrt, CrtKeySender, Csr, LocalCrtKey}, + token::TokenSource, +}; +pub use linkerd_identity::*; +pub use linkerd_tls_rustls::*; diff --git a/linkerd/proxy/identity/src/metrics.rs b/linkerd/proxy/identity/src/metrics.rs index 9e4f04aeca..d158da365c 100644 --- a/linkerd/proxy/identity/src/metrics.rs +++ b/linkerd/proxy/identity/src/metrics.rs @@ -1,5 +1,5 @@ -use linkerd_identity::CrtKey; use linkerd_metrics::{metrics, Counter, FmtMetrics, Gauge}; +use linkerd_tls_rustls::CrtKey; use std::{fmt, sync::Arc, time::UNIX_EPOCH}; use tokio::sync::watch; diff --git a/linkerd/proxy/identity/src/token.rs b/linkerd/proxy/identity/src/token.rs new file mode 100644 index 0000000000..41842f902b --- /dev/null +++ b/linkerd/proxy/identity/src/token.rs @@ -0,0 +1,24 @@ +use std::{io, path::PathBuf}; + +#[derive(Clone, Debug)] +pub struct TokenSource(PathBuf); + +// === impl TokenSource === + +impl TokenSource { + pub fn if_nonempty_file(p: impl Into) -> io::Result { + let ts = TokenSource(p.into()); + ts.load()?; + Ok(ts) + } + + pub fn load(&self) -> io::Result> { + let t = std::fs::read(&self.0)?; + + if t.is_empty() { + return Err(io::Error::new(io::ErrorKind::Other, "token is empty")); + } + + Ok(t) + } +} diff --git a/linkerd/proxy/tap/Cargo.toml b/linkerd/proxy/tap/Cargo.toml index 859e436df2..af05936f4e 100644 --- a/linkerd/proxy/tap/Cargo.toml +++ b/linkerd/proxy/tap/Cargo.toml @@ -23,6 +23,7 @@ linkerd-proxy-http = { path = "../http" } linkerd-proxy-transport = { path = "../transport" } linkerd-stack = { path = "../../stack" } linkerd-tls = { path = "../../tls" } +linkerd-tls-rustls = { path = "../../tls/rustls" } parking_lot = "0.11" rand = { version = "0.8" } thiserror = "1.0" diff --git a/linkerd/proxy/tap/src/accept.rs b/linkerd/proxy/tap/src/accept.rs index aa1514ef5d..d8fcb64374 100644 --- a/linkerd/proxy/tap/src/accept.rs +++ b/linkerd/proxy/tap/src/accept.rs @@ -5,7 +5,7 @@ use linkerd_conditional::Conditional; use linkerd_error::Error; use linkerd_io as io; use linkerd_proxy_http::{trace, HyperServerSvc}; -use linkerd_tls::{self as tls}; +use linkerd_tls as tls; use std::{ collections::HashSet, future::Future, @@ -21,7 +21,10 @@ pub struct AcceptPermittedClients { server: Server, } -type Connection = ((tls::ConditionalServerTls, T), tls::server::Io); +type Connection = ( + (tls::ConditionalServerTls, T), + io::EitherIo>, tls::server::DetectIo>, +); pub type ServeFuture = Pin> + Send + 'static>>; diff --git a/linkerd/tls/Cargo.toml b/linkerd/tls/Cargo.toml index 2ae9e5097f..43df6dc2ef 100644 --- a/linkerd/tls/Cargo.toml +++ b/linkerd/tls/Cargo.toml @@ -16,16 +16,15 @@ linkerd-error = { path = "../error" } linkerd-identity = { path = "../identity" } linkerd-io = { path = "../io" } linkerd-stack = { path = "../stack" } +pin-project = "1" thiserror = "1.0" tokio = { version = "1", features = ["macros", "time"] } -tokio-rustls = "0.22" -tower = "0.4.9" +tower = "0.4" tracing = "0.1.29" -webpki = "0.21" untrusted = "0.7" [dev-dependencies] -linkerd-identity = { path = "../identity", features = ["test-util"] } +linkerd-tls-rustls = { path = "rustls", features = ["test-util"] } linkerd-proxy-transport = { path = "../proxy/transport" } linkerd-tracing = { path = "../tracing", features = ["ansi"] } tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/linkerd/tls/fuzz/Cargo.toml b/linkerd/tls/fuzz/Cargo.toml index 5185c5e198..23ff35bba4 100644 --- a/linkerd/tls/fuzz/Cargo.toml +++ b/linkerd/tls/fuzz/Cargo.toml @@ -24,6 +24,3 @@ name = "fuzz_target_1" path = "fuzz_targets/fuzz_target_1.rs" test = false doc = false - -[patch.crates-io] -webpki = { git = "https://github.com/linkerd/webpki", branch = "cert-dns-names-0.21"} diff --git a/linkerd/tls/rustls/Cargo.toml b/linkerd/tls/rustls/Cargo.toml new file mode 100644 index 0000000000..5655aa7540 --- /dev/null +++ b/linkerd/tls/rustls/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "linkerd-tls-rustls" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2018" +publish = false + +[features] +default = [] +test-util = [] + +[dependencies] +futures = { version = "0.3", default-features = false } +linkerd-identity = { path = "../../identity" } +linkerd-io = { path = "../../io" } +linkerd-stack = { path = "../../stack" } +linkerd-tls = { path = ".." } +ring = "0.16.19" +thiserror = "1" +tokio-rustls = "0.22" +tracing = "0.1" +webpki = "0.21" diff --git a/linkerd/tls/rustls/src/client.rs b/linkerd/tls/rustls/src/client.rs new file mode 100644 index 0000000000..638643cb8e --- /dev/null +++ b/linkerd/tls/rustls/src/client.rs @@ -0,0 +1,136 @@ +use futures::prelude::*; +use linkerd_io as io; +use linkerd_stack::Service; +use linkerd_tls::{client::AlpnProtocols, ClientTls, HasNegotiatedProtocol, NegotiatedProtocolRef}; +use std::{pin::Pin, sync::Arc}; +use tokio_rustls::rustls::{ClientConfig, Session}; + +#[derive(Clone)] +pub struct Connect { + server_id: webpki::DNSName, + config: Arc, +} + +pub type ConnectFuture = futures::future::MapOk< + tokio_rustls::Connect, + fn(tokio_rustls::client::TlsStream) -> ClientIo, +>; + +#[derive(Debug)] +pub struct ClientIo(tokio_rustls::client::TlsStream); + +// === impl Connect === + +impl Connect { + pub fn new(client_tls: ClientTls, config: Arc) -> Self { + // If ALPN protocols are configured by the endpoint, we have to clone the + // entire configuration and set the protocols. If there are no + // ALPN options, clone the Arc'd base configuration without + // extra allocation. + // + // TODO it would be better to avoid cloning the whole TLS config + // per-connection. + let config = match client_tls.alpn { + None => config, + Some(AlpnProtocols(protocols)) => { + let mut c: ClientConfig = config.as_ref().clone(); + c.alpn_protocols = protocols; + Arc::new(c) + } + }; + + let server_id = webpki::DNSNameRef::try_from_ascii(client_tls.server_id.as_bytes()) + .expect("identity must be a valid DNS name") + .to_owned(); + + Self { server_id, config } + } +} + +impl Service for Connect +where + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, +{ + type Response = ClientIo; + type Error = io::Error; + type Future = ConnectFuture; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { + tokio_rustls::TlsConnector::from(self.config.clone()) + .connect(self.server_id.as_ref(), io) + .map_ok(ClientIo) + } +} + +// === impl ClientIo === + +impl io::AsyncRead for ClientIo { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> io::Poll<()> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl io::AsyncWrite for ClientIo { + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> io::Poll<()> { + Pin::new(&mut self.0).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> io::Poll<()> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> io::Poll { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> std::task::Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} + +impl HasNegotiatedProtocol for ClientIo { + #[inline] + fn negotiated_protocol(&self) -> Option> { + self.0 + .get_ref() + .1 + .get_alpn_protocol() + .map(NegotiatedProtocolRef) + } +} + +impl io::PeerAddr for ClientIo { + #[inline] + fn peer_addr(&self) -> io::Result { + self.0.get_ref().0.peer_addr() + } +} diff --git a/linkerd/tls/rustls/src/lib.rs b/linkerd/tls/rustls/src/lib.rs new file mode 100644 index 0000000000..93e2eac166 --- /dev/null +++ b/linkerd/tls/rustls/src/lib.rs @@ -0,0 +1,311 @@ +#![deny(warnings, rust_2018_idioms)] +#![forbid(unsafe_code)] + +mod client; +mod server; +#[cfg(feature = "test-util")] +pub mod test_util; + +pub use self::{ + client::{ClientIo, Connect, ConnectFuture}, + server::{terminate, ServerIo, TerminateFuture}, +}; +use linkerd_identity as id; +pub use ring::error::KeyRejected; +use ring::{rand, signature::EcdsaKeyPair}; +use std::{sync::Arc, time::SystemTime}; +use thiserror::Error; +pub use tokio_rustls::rustls::*; +use tracing::{debug, warn}; + +#[derive(Clone, Debug)] +pub struct Key(Arc); + +struct SigningKey(Arc); +struct Signer(Arc); + +#[derive(Clone)] +pub struct TrustAnchors(Arc); + +#[derive(Clone, Debug)] +pub struct Crt { + id: id::LocalId, + expiry: SystemTime, + chain: Vec, +} + +#[derive(Clone)] +pub struct CrtKey { + id: id::LocalId, + expiry: SystemTime, + client_config: Arc, + server_config: Arc, +} + +struct CertResolver(sign::CertifiedKey); + +#[derive(Clone, Debug, Error)] +#[error(transparent)] +pub struct InvalidCrt(TLSError); + +// These must be kept in sync: +static SIGNATURE_ALG_RING_SIGNING: &ring::signature::EcdsaSigningAlgorithm = + &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const SIGNATURE_ALG_RUSTLS_SCHEME: SignatureScheme = SignatureScheme::ECDSA_NISTP256_SHA256; +const SIGNATURE_ALG_RUSTLS_ALGORITHM: internal::msgs::enums::SignatureAlgorithm = + internal::msgs::enums::SignatureAlgorithm::ECDSA; +const TLS_VERSIONS: &[ProtocolVersion] = &[ProtocolVersion::TLSv1_3]; + +// === impl Key === + +impl Key { + pub fn from_pkcs8(b: &[u8]) -> Result { + let k = EcdsaKeyPair::from_pkcs8(SIGNATURE_ALG_RING_SIGNING, b)?; + Ok(Key(Arc::new(k))) + } +} + +impl sign::SigningKey for SigningKey { + fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { + if offered.contains(&SIGNATURE_ALG_RUSTLS_SCHEME) { + Some(Box::new(Signer(self.0.clone()))) + } else { + None + } + } + + fn algorithm(&self) -> internal::msgs::enums::SignatureAlgorithm { + SIGNATURE_ALG_RUSTLS_ALGORITHM + } +} + +impl sign::Signer for Signer { + fn sign(&self, message: &[u8]) -> Result, TLSError> { + let rng = rand::SystemRandom::new(); + self.0 + .sign(&rng, message) + .map(|signature| signature.as_ref().to_owned()) + .map_err(|ring::error::Unspecified| TLSError::General("Signing Failed".to_owned())) + } + + fn get_scheme(&self) -> SignatureScheme { + SIGNATURE_ALG_RUSTLS_SCHEME + } +} + +// === impl TrustAnchors === + +impl TrustAnchors { + #[cfg(feature = "test-util")] + fn empty() -> Self { + TrustAnchors(Arc::new(ClientConfig::new())) + } + + pub fn from_pem(s: &str) -> Option { + use std::io::Cursor; + + let mut roots = RootCertStore::empty(); + let (added, skipped) = roots.add_pem_file(&mut Cursor::new(s)).ok()?; + if skipped != 0 { + warn!("skipped {} trust anchors in trust anchors file", skipped); + } + if added == 0 { + return None; + } + + let mut c = ClientConfig::new(); + + // XXX: Rustls's built-in verifiers don't let us tweak things as fully + // as we'd like (e.g. controlling the set of trusted signature + // algorithms), but they provide good enough defaults for now. + // TODO: lock down the verification further. + // TODO: Change Rustls's API to Avoid needing to clone `root_cert_store`. + c.root_store = roots; + + // Disable session resumption for the time-being until resumption is + // more tested. + c.enable_tickets = false; + + Some(TrustAnchors(Arc::new(c))) + } + + pub fn certify(&self, key: Key, crt: Crt) -> Result { + let mut client = self.0.as_ref().clone(); + + // Ensure the certificate is valid for the services we terminate for + // TLS. This assumes that server cert validation does the same or + // more validation than client cert validation. + // + // XXX: Rustls currently only provides access to a + // `ServerCertVerifier` through + // `ClientConfig::get_verifier()`. + // + // XXX: Once `ServerCertVerified` is exposed in Rustls's + // safe API, use it to pass proof to CertCertResolver::new.... + // + // TODO: Restrict accepted signature algorithms. + static NO_OCSP: &[u8] = &[]; + let crt_id = webpki::DNSNameRef::try_from_ascii((***crt.id).as_bytes()) + .map_err(|e| InvalidCrt(TLSError::General(e.to_string())))?; + client + .get_verifier() + .verify_server_cert(&client.root_store, &crt.chain, crt_id, NO_OCSP) + .map_err(InvalidCrt)?; + debug!("certified {}", crt.id); + + let k = SigningKey(key.0); + let key = sign::CertifiedKey::new(crt.chain, Arc::new(Box::new(k))); + let resolver = Arc::new(CertResolver(key)); + + // Enable client authentication. + client.client_auth_cert_resolver = resolver.clone(); + + // Ask TLS clients for a certificate and accept any certificate issued + // by our trusted CA(s). + // + // XXX: Rustls's built-in verifiers don't let us tweak things as fully + // as we'd like (e.g. controlling the set of trusted signature + // algorithms), but they provide good enough defaults for now. + // TODO: lock down the verification further. + // + // TODO: Change Rustls's API to Avoid needing to clone `root_cert_store`. + let mut server = ServerConfig::new(AllowAnyAnonymousOrAuthenticatedClient::new( + self.0.root_store.clone(), + )); + server.versions = TLS_VERSIONS.to_vec(); + server.cert_resolver = resolver; + + Ok(CrtKey { + id: crt.id, + expiry: crt.expiry, + client_config: Arc::new(client), + server_config: Arc::new(server), + }) + } + + pub fn client_config(&self) -> Arc { + self.0.clone() + } +} + +impl std::fmt::Debug for TrustAnchors { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TrustAnchors").finish() + } +} + +// === Crt === + +impl Crt { + pub fn new( + id: id::LocalId, + leaf: Vec, + intermediates: Vec>, + expiry: SystemTime, + ) -> Self { + let mut chain = Vec::with_capacity(intermediates.len() + 1); + chain.push(Certificate(leaf)); + chain.extend(intermediates.into_iter().map(Certificate)); + + Self { id, chain, expiry } + } + + pub fn name(&self) -> &id::Name { + &self.id.0 + } +} + +impl From<&'_ Crt> for id::LocalId { + fn from(crt: &Crt) -> id::LocalId { + crt.id.clone() + } +} + +// === CrtKey === + +impl CrtKey { + pub fn name(&self) -> &id::Name { + &self.id.0 + } + + pub fn expiry(&self) -> SystemTime { + self.expiry + } + + pub fn id(&self) -> &id::LocalId { + &self.id + } + + pub fn client_config(&self) -> Arc { + self.client_config.clone() + } + + pub fn server_config(&self) -> Arc { + self.server_config.clone() + } +} + +impl std::fmt::Debug for CrtKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("CrtKey") + .field("id", &self.id) + .field("expiry", &self.expiry) + .finish() + } +} + +// === impl CertResolver === + +impl ResolvesClientCert for CertResolver { + fn resolve( + &self, + _acceptable_issuers: &[&[u8]], + sigschemes: &[SignatureScheme], + ) -> Option { + // The proxy's server-side doesn't send the list of acceptable issuers so + // don't bother looking at `_acceptable_issuers`. + self.resolve_(sigschemes) + } + + fn has_certs(&self) -> bool { + true + } +} + +impl CertResolver { + fn resolve_(&self, sigschemes: &[SignatureScheme]) -> Option { + if !sigschemes.contains(&SIGNATURE_ALG_RUSTLS_SCHEME) { + debug!("signature scheme not supported -> no certificate"); + return None; + } + Some(self.0.clone()) + } +} + +impl ResolvesServerCert for CertResolver { + fn resolve(&self, hello: ClientHello<'_>) -> Option { + let server_name = if let Some(server_name) = hello.server_name() { + server_name + } else { + debug!("no SNI -> no certificate"); + return None; + }; + + // Verify that our certificate is valid for the given SNI name. + let c = (&self.0.cert) + .first() + .map(Certificate::as_ref) + .unwrap_or(&[]); // An empty input will fail to parse. + if let Err(err) = + webpki::EndEntityCert::from(c).and_then(|c| c.verify_is_valid_for_dns_name(server_name)) + { + debug!( + "our certificate is not valid for the SNI name -> no certificate: {:?}", + err + ); + return None; + } + + self.resolve_(hello.sigschemes()) + } +} diff --git a/linkerd/tls/rustls/src/server.rs b/linkerd/tls/rustls/src/server.rs new file mode 100644 index 0000000000..e0acfde629 --- /dev/null +++ b/linkerd/tls/rustls/src/server.rs @@ -0,0 +1,127 @@ +use futures::prelude::*; +use linkerd_io as io; +use linkerd_tls::{ + ClientId, HasNegotiatedProtocol, NegotiatedProtocol, NegotiatedProtocolRef, ServerTls, +}; +use std::{pin::Pin, sync::Arc}; +use tokio_rustls::rustls::{Certificate, ServerConfig, Session}; +use tracing::debug; + +pub type TerminateFuture = futures::future::MapOk< + tokio_rustls::Accept, + fn(tokio_rustls::server::TlsStream) -> (ServerTls, ServerIo), +>; + +#[derive(Debug)] +pub struct ServerIo(tokio_rustls::server::TlsStream); + +/// Terminates a TLS connection. +pub fn terminate(config: Arc, io: I) -> TerminateFuture +where + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, +{ + tokio_rustls::TlsAcceptor::from(config) + .accept(io) + .map_ok(|io| { + // Determine the peer's identity, if it exist. + let client_id = client_identity(&io); + + let negotiated_protocol = io + .get_ref() + .1 + .get_alpn_protocol() + .map(|b| NegotiatedProtocol(b.into())); + + debug!(client.id = ?client_id, alpn = ?negotiated_protocol, "Accepted TLS connection"); + let tls = ServerTls::Established { + client_id, + negotiated_protocol, + }; + (tls, ServerIo(io)) + }) +} + +fn client_identity(tls: &tokio_rustls::server::TlsStream) -> Option { + let (_io, session) = tls.get_ref(); + let certs = session.get_peer_certificates()?; + let c = certs.first().map(Certificate::as_ref)?; + let end_cert = webpki::EndEntityCert::from(c).ok()?; + let dns_names = end_cert.dns_names().ok()?; + + match dns_names.first()? { + webpki::GeneralDNSNameRef::DNSName(n) => { + let s: &str = (*n).into(); + s.parse().ok().map(ClientId) + } + webpki::GeneralDNSNameRef::Wildcard(_) => { + // Wildcards can perhaps be handled in a future path... + None + } + } +} + +// === impl ServerIo === + +impl io::AsyncRead for ServerIo { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> io::Poll<()> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl io::AsyncWrite for ServerIo { + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> io::Poll<()> { + Pin::new(&mut self.0).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> io::Poll<()> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> io::Poll { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> std::task::Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} + +impl HasNegotiatedProtocol for ServerIo { + #[inline] + fn negotiated_protocol(&self) -> Option> { + self.0 + .get_ref() + .1 + .get_alpn_protocol() + .map(NegotiatedProtocolRef) + } +} + +impl io::PeerAddr for ServerIo { + #[inline] + fn peer_addr(&self) -> io::Result { + self.0.get_ref().0.peer_addr() + } +} diff --git a/linkerd/identity/src/test_util.rs b/linkerd/tls/rustls/src/test_util.rs similarity index 94% rename from linkerd/identity/src/test_util.rs rename to linkerd/tls/rustls/src/test_util.rs index bd24f4f315..eb8e23c4fc 100644 --- a/linkerd/identity/src/test_util.rs +++ b/linkerd/tls/rustls/src/test_util.rs @@ -1,4 +1,5 @@ use super::*; +use linkerd_identity::{LocalId, Name}; use std::time::{Duration, SystemTime}; pub struct Identity { @@ -46,7 +47,7 @@ impl Identity { pub fn crt(&self) -> Crt { const HOUR: Duration = Duration::from_secs(60 * 60); - let n = Name::from_str(self.name).expect("name must be valid"); + let n = self.name.parse::().expect("name must be valid"); let der = self.crt.iter().copied().collect(); Crt::new(LocalId(n), der, vec![], SystemTime::now() + HOUR) } diff --git a/linkerd/identity/src/testdata/bar-ns1-ca1/crt.der b/linkerd/tls/rustls/src/testdata/bar-ns1-ca1/crt.der similarity index 100% rename from linkerd/identity/src/testdata/bar-ns1-ca1/crt.der rename to linkerd/tls/rustls/src/testdata/bar-ns1-ca1/crt.der diff --git a/linkerd/identity/src/testdata/bar-ns1-ca1/csr.pem b/linkerd/tls/rustls/src/testdata/bar-ns1-ca1/csr.pem similarity index 100% rename from linkerd/identity/src/testdata/bar-ns1-ca1/csr.pem rename to linkerd/tls/rustls/src/testdata/bar-ns1-ca1/csr.pem diff --git a/linkerd/identity/src/testdata/bar-ns1-ca1/key.p8 b/linkerd/tls/rustls/src/testdata/bar-ns1-ca1/key.p8 similarity index 100% rename from linkerd/identity/src/testdata/bar-ns1-ca1/key.p8 rename to linkerd/tls/rustls/src/testdata/bar-ns1-ca1/key.p8 diff --git a/linkerd/identity/src/testdata/ca-config.json b/linkerd/tls/rustls/src/testdata/ca-config.json similarity index 100% rename from linkerd/identity/src/testdata/ca-config.json rename to linkerd/tls/rustls/src/testdata/ca-config.json diff --git a/linkerd/identity/src/testdata/ca1-key.pem b/linkerd/tls/rustls/src/testdata/ca1-key.pem similarity index 100% rename from linkerd/identity/src/testdata/ca1-key.pem rename to linkerd/tls/rustls/src/testdata/ca1-key.pem diff --git a/linkerd/identity/src/testdata/ca1.pem b/linkerd/tls/rustls/src/testdata/ca1.pem similarity index 100% rename from linkerd/identity/src/testdata/ca1.pem rename to linkerd/tls/rustls/src/testdata/ca1.pem diff --git a/linkerd/identity/src/testdata/ca2-key.pem b/linkerd/tls/rustls/src/testdata/ca2-key.pem similarity index 100% rename from linkerd/identity/src/testdata/ca2-key.pem rename to linkerd/tls/rustls/src/testdata/ca2-key.pem diff --git a/linkerd/identity/src/testdata/ca2.pem b/linkerd/tls/rustls/src/testdata/ca2.pem similarity index 100% rename from linkerd/identity/src/testdata/ca2.pem rename to linkerd/tls/rustls/src/testdata/ca2.pem diff --git a/linkerd/identity/src/testdata/controller-linkerd-ca1/crt.der b/linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/crt.der similarity index 100% rename from linkerd/identity/src/testdata/controller-linkerd-ca1/crt.der rename to linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/crt.der diff --git a/linkerd/identity/src/testdata/controller-linkerd-ca1/csr.pem b/linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/csr.pem similarity index 100% rename from linkerd/identity/src/testdata/controller-linkerd-ca1/csr.pem rename to linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/csr.pem diff --git a/linkerd/identity/src/testdata/controller-linkerd-ca1/key.p8 b/linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/key.p8 similarity index 100% rename from linkerd/identity/src/testdata/controller-linkerd-ca1/key.p8 rename to linkerd/tls/rustls/src/testdata/controller-linkerd-ca1/key.p8 diff --git a/linkerd/identity/src/testdata/default-default-ca1/crt.der b/linkerd/tls/rustls/src/testdata/default-default-ca1/crt.der similarity index 100% rename from linkerd/identity/src/testdata/default-default-ca1/crt.der rename to linkerd/tls/rustls/src/testdata/default-default-ca1/crt.der diff --git a/linkerd/identity/src/testdata/default-default-ca1/csr.pem b/linkerd/tls/rustls/src/testdata/default-default-ca1/csr.pem similarity index 100% rename from linkerd/identity/src/testdata/default-default-ca1/csr.pem rename to linkerd/tls/rustls/src/testdata/default-default-ca1/csr.pem diff --git a/linkerd/identity/src/testdata/default-default-ca1/key.p8 b/linkerd/tls/rustls/src/testdata/default-default-ca1/key.p8 similarity index 100% rename from linkerd/identity/src/testdata/default-default-ca1/key.p8 rename to linkerd/tls/rustls/src/testdata/default-default-ca1/key.p8 diff --git a/linkerd/identity/src/testdata/foo-ns1-ca1/crt.der b/linkerd/tls/rustls/src/testdata/foo-ns1-ca1/crt.der similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca1/crt.der rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca1/crt.der diff --git a/linkerd/identity/src/testdata/foo-ns1-ca1/csr.pem b/linkerd/tls/rustls/src/testdata/foo-ns1-ca1/csr.pem similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca1/csr.pem rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca1/csr.pem diff --git a/linkerd/identity/src/testdata/foo-ns1-ca1/key.p8 b/linkerd/tls/rustls/src/testdata/foo-ns1-ca1/key.p8 similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca1/key.p8 rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca1/key.p8 diff --git a/linkerd/identity/src/testdata/foo-ns1-ca2/crt.der b/linkerd/tls/rustls/src/testdata/foo-ns1-ca2/crt.der similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca2/crt.der rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca2/crt.der diff --git a/linkerd/identity/src/testdata/foo-ns1-ca2/csr.pem b/linkerd/tls/rustls/src/testdata/foo-ns1-ca2/csr.pem similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca2/csr.pem rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca2/csr.pem diff --git a/linkerd/identity/src/testdata/foo-ns1-ca2/key.p8 b/linkerd/tls/rustls/src/testdata/foo-ns1-ca2/key.p8 similarity index 100% rename from linkerd/identity/src/testdata/foo-ns1-ca2/key.p8 rename to linkerd/tls/rustls/src/testdata/foo-ns1-ca2/key.p8 diff --git a/linkerd/identity/src/testdata/gen-certs.sh b/linkerd/tls/rustls/src/testdata/gen-certs.sh similarity index 100% rename from linkerd/identity/src/testdata/gen-certs.sh rename to linkerd/tls/rustls/src/testdata/gen-certs.sh diff --git a/linkerd/tls/rustls/src/tests.rs b/linkerd/tls/rustls/src/tests.rs new file mode 100644 index 0000000000..abf27117d0 --- /dev/null +++ b/linkerd/tls/rustls/src/tests.rs @@ -0,0 +1,35 @@ +use super::test_util::*; + +#[test] +fn can_construct_client_and_server_config_from_valid_settings() { + FOO_NS1.validate().expect("foo.ns1 must be valid"); +} + +#[test] +fn recognize_ca_did_not_issue_cert() { + let s = Identity { + trust_anchors: include_bytes!("testdata/ca2.pem"), + ..FOO_NS1 + }; + assert!(s.validate().is_err(), "ca2 should not validate foo.ns1"); +} + +#[test] +fn recognize_cert_is_not_valid_for_identity() { + let s = Identity { + crt: BAR_NS1.crt, + key: BAR_NS1.key, + ..FOO_NS1 + }; + assert!(s.validate().is_err(), "identity should not be valid"); +} + +#[test] +#[ignore] // XXX this doesn't fail because we don't actually check the key against the cert... +fn recognize_private_key_is_not_valid_for_cert() { + let s = Identity { + key: BAR_NS1.key, + ..FOO_NS1 + }; + assert!(s.validate().is_err(), "identity should not be valid"); +} diff --git a/linkerd/tls/src/client.rs b/linkerd/tls/src/client.rs index 0acdb7d3a7..340534f8a0 100644 --- a/linkerd/tls/src/client.rs +++ b/linkerd/tls/src/client.rs @@ -1,29 +1,24 @@ -use futures::{ - future::{Either, MapOk}, - prelude::*, -}; +use crate::{HasNegotiatedProtocol, NegotiatedProtocolRef}; +use futures::prelude::*; use linkerd_conditional::Conditional; use linkerd_identity as id; use linkerd_io as io; -use linkerd_stack::{layer, Param}; +use linkerd_stack::{layer, NewService, Oneshot, Param, Service, ServiceExt}; use std::{ fmt, future::Future, ops::Deref, pin::Pin, str::FromStr, - sync::Arc, task::{Context, Poll}, }; -pub use tokio_rustls::client::TlsStream; -use tokio_rustls::rustls::{self, Session}; use tracing::debug; /// A newtype for target server identities. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ServerId(pub id::Name); -/// A stack paramter that configures a `Client` to establish a TLS connection. +/// A stack parameter that configures a `Client` to establish a TLS connection. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ClientTls { pub server_id: ServerId, @@ -55,19 +50,18 @@ pub enum NoClientTls { /// known TLS identity. pub type ConditionalClientTls = Conditional; -pub type Config = Arc; - #[derive(Clone, Debug)] pub struct Client { - local: L, + identity: L, inner: C, } -type Connect = MapOk io::EitherIo>>; -type Handshake = - Pin>>> + Send + 'static>>; - -pub type Io = io::EitherIo>; +#[pin_project::pin_project(project = ConnectProj)] +#[derive(Debug)] +pub enum Connect> { + Connect(#[pin] F, Option), + Handshake(#[pin] Oneshot), +} // === impl ClientTls === @@ -83,25 +77,28 @@ impl From for ClientTls { // === impl Client === impl Client { - pub fn layer(local: L) -> impl layer::Layer + Clone { + pub fn layer(identity: L) -> impl layer::Layer + Clone { layer::mk(move |inner| Self { inner, - local: local.clone(), + identity: identity.clone(), }) } } -impl tower::Service for Client +impl Service for Client where - L: Clone + Param, T: Param, - C: tower::Service, + L: NewService, + C: Service, C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin, C::Future: Send + 'static, + H: Service + Send + 'static, + H::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + HasNegotiatedProtocol, + H::Future: Send + 'static, { - type Response = Io; + type Response = io::EitherIo; type Error = io::Error; - type Future = Either, Handshake>; + type Future = Connect; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -109,44 +106,49 @@ where } fn call(&mut self, target: T) -> Self::Future { - let ClientTls { server_id, alpn } = match target.param() { - Conditional::Some(tls) => tls, + let handshake = match target.param() { + Conditional::Some(tls) => Some(self.identity.new_service(tls)), Conditional::None(reason) => { debug!(%reason, "Peer does not support TLS"); - return Either::Left(self.inner.call(target).map_ok(io::EitherIo::Left)); - } - }; - - // Build a rustls ClientConfig for this connection. - // - // If ALPN protocols are configured by the endpoint, we have to clone the - // entire configuration and set the protocols. If there are no - // ALPN options, clone the Arc'd base configuration without - // extra allocation. - // - // TODO it would be better to avoid cloning the whole TLS config - // per-connection. - let handshake = match alpn { - None => tokio_rustls::TlsConnector::from(self.local.param()), - Some(AlpnProtocols(protocols)) => { - let mut config: rustls::ClientConfig = self.local.param().as_ref().clone(); - config.alpn_protocols = protocols; - tokio_rustls::TlsConnector::from(Arc::new(config)) + None } }; - debug!(server.id = %server_id, "Initiating TLS connection"); let connect = self.inner.call(target); - Either::Right(Box::pin(async move { - let io = connect.await?; - let sni = webpki::DNSNameRef::try_from_ascii(server_id.as_bytes()) - .expect("identity must be a valid DNS-like name"); - let io = handshake.connect(sni, io).await?; - if let Some(alpn) = io.get_ref().1.get_alpn_protocol() { - debug!(alpn = ?std::str::from_utf8(alpn)); + Connect::Connect(connect, handshake) + } +} + +impl Future for Connect +where + F: TryFuture, + H: Service, + H::Response: HasNegotiatedProtocol, +{ + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match self.as_mut().project() { + ConnectProj::Connect(fut, tls) => { + let io = futures::ready!(fut.try_poll(cx))?; + match tls.take() { + None => return Poll::Ready(Ok(io::EitherIo::Left(io))), + Some(tls) => self.set(Connect::Handshake(tls.oneshot(io))), + } + } + ConnectProj::Handshake(fut) => { + let io = futures::ready!(fut.try_poll(cx))?; + debug!( + alpn = io + .negotiated_protocol() + .and_then(|NegotiatedProtocolRef(p)| std::str::from_utf8(p).ok()) + .map(tracing::field::display) + ); + return Poll::Ready(Ok(io::EitherIo::Right(io))); + } } - Ok(io::EitherIo::Right(io)) - })) + } } } diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index b9e17d7251..baa1c8b805 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -1,19 +1,18 @@ #![deny(warnings, rust_2018_idioms)] #![forbid(unsafe_code)] -pub use linkerd_identity::LocalId; -use linkerd_io as io; -pub use tokio_rustls::rustls::Session; - pub mod client; pub mod server; +pub use linkerd_identity::LocalId; +use linkerd_io as io; + pub use self::{ client::{Client, ClientTls, ConditionalClientTls, NoClientTls, ServerId}, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; -/// A trait implented by transport streams to indicate its negotiated protocol. +/// A trait implemented by transport streams to indicate its negotiated protocol. pub trait HasNegotiatedProtocol { fn negotiated_protocol(&self) -> Option>; } @@ -58,26 +57,6 @@ impl std::fmt::Debug for NegotiatedProtocolRef<'_> { } } -impl HasNegotiatedProtocol for self::client::TlsStream { - #[inline] - fn negotiated_protocol(&self) -> Option> { - self.get_ref() - .1 - .get_alpn_protocol() - .map(NegotiatedProtocolRef) - } -} - -impl HasNegotiatedProtocol for self::server::TlsStream { - #[inline] - fn negotiated_protocol(&self) -> Option> { - self.get_ref() - .1 - .get_alpn_protocol() - .map(NegotiatedProtocolRef) - } -} - impl HasNegotiatedProtocol for tokio::net::TcpStream { #[inline] fn negotiated_protocol(&self) -> Option> { diff --git a/linkerd/tls/src/server/mod.rs b/linkerd/tls/src/server/mod.rs index 2b79bc9131..b82c79d31a 100644 --- a/linkerd/tls/src/server/mod.rs +++ b/linkerd/tls/src/server/mod.rs @@ -1,42 +1,29 @@ mod client_hello; -use crate::{LocalId, NegotiatedProtocol, ServerId}; +use crate::{NegotiatedProtocol, ServerId}; use bytes::BytesMut; use futures::prelude::*; use linkerd_conditional::Conditional; -use linkerd_dns_name as dns; use linkerd_error::Error; use linkerd_identity as id; use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; -use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Param}; +use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Param, Service, ServiceExt}; use std::{ fmt, ops::Deref, pin::Pin, str::FromStr, - sync::Arc, task::{Context, Poll}, }; use thiserror::Error; use tokio::time::{self, Duration}; -use tokio_rustls::rustls::{self, Session}; -pub use tokio_rustls::server::TlsStream; -use tower::util::ServiceExt; use tracing::{debug, trace, warn}; -pub type Config = Arc; - -/// Produces a server config that fails to handshake all connections. -pub fn empty_config() -> Config { - let verifier = rustls::NoClientAuth::new(); - Arc::new(rustls::ServerConfig::new(verifier)) -} - /// A newtype for remote client idenities. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ClientId(pub id::Name); -/// Indicates a serverside connection's TLS status. +/// Indicates a server-side connection's TLS status. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum ServerTls { Established { @@ -68,9 +55,9 @@ pub enum NoServerTls { /// Indicates whether TLS was established on an accepted connection. pub type ConditionalServerTls = Conditional; -type DetectIo = EitherIo>; +pub type DetectIo = EitherIo>; -pub type Io = EitherIo>, DetectIo>; +pub type Io = EitherIo>; #[derive(Clone, Debug)] pub struct NewDetectTls { @@ -140,15 +127,18 @@ where } } -impl tower::Service for DetectTls +impl Service for DetectTls where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin + 'static, T: Clone + Send + 'static, P: InsertParam + Clone + Send + Sync + 'static, P::Target: Send + 'static, - L: Param + Param, + L: Param + Clone + Send + 'static, + L: Service, Response = (ServerTls, LIo), Error = io::Error>, + L::Future: Send, + LIo: io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin + 'static, N: NewService + Clone + Send + 'static, - NSvc: tower::Service, Response = ()> + Send + 'static, + NSvc: Service, Response = ()> + Send + 'static, NSvc::Error: Into, NSvc::Future: Send, { @@ -165,8 +155,7 @@ where let params = self.params.clone(); let new_accept = self.inner.clone(); - let config: Config = self.local_identity.param(); - let LocalId(local_id) = self.local_identity.param(); + let tls = self.local_identity.clone(); // Detect the SNI from a ClientHello (or timeout). let Timeout(timeout) = self.timeout; @@ -174,11 +163,12 @@ where Box::pin(async move { let (sni, io) = detect.await.map_err(|_| ServerTlsTimeoutError(()))??; + let id::LocalId(id) = tls.param(); let (peer, io) = match sni { // If we detected an SNI matching this proxy, terminate TLS. - Some(ServerId(id)) if id == local_id => { + Some(ServerId(sni)) if sni == id => { trace!("Identified local SNI"); - let (peer, io) = handshake(config, io).await?; + let (peer, io) = tls.oneshot(io).await?; (Conditional::Some(peer), EitherIo::Left(io)) } // If we detected another SNI, continue proxying the @@ -256,56 +246,6 @@ where Ok((None, io)) } -async fn handshake(tls_config: Config, io: T) -> io::Result<(ServerTls, TlsStream)> -where - T: io::AsyncRead + io::AsyncWrite + Unpin, -{ - let io = tokio_rustls::TlsAcceptor::from(tls_config) - .accept(io) - .await?; - - // Determine the peer's identity, if it exist. - let client_id = client_identity(&io); - - let negotiated_protocol = io - .get_ref() - .1 - .get_alpn_protocol() - .map(|b| NegotiatedProtocol(b.into())); - - debug!(client.id = ?client_id, alpn = ?negotiated_protocol, "Accepted TLS connection"); - let tls = ServerTls::Established { - client_id, - negotiated_protocol, - }; - Ok((tls, io)) -} - -fn client_identity(tls: &TlsStream) -> Option { - use webpki::GeneralDNSNameRef; - - let (_io, session) = tls.get_ref(); - let certs = session.get_peer_certificates()?; - let c = certs.first().map(rustls::Certificate::as_ref)?; - let end_cert = webpki::EndEntityCert::from(c).ok()?; - let dns_names = end_cert.dns_names().ok()?; - - match dns_names.first()? { - GeneralDNSNameRef::DNSName(n) => { - // Unfortunately we have to allocate a new string here, since there's no way to get the - // underlying bytes from a `DNSNameRef`. - let name = AsRef::::as_ref(&n.to_owned()) - .parse::() - .ok()?; - Some(ClientId(name.into())) - } - GeneralDNSNameRef::Wildcard(_) => { - // Wildcards can perhaps be handled in a future path... - None - } - } -} - // === impl ClientId === impl From for ClientId { diff --git a/linkerd/tls/tests/tls_accept.rs b/linkerd/tls/tests/tls_accept.rs index 2440ae3972..e93aa098ae 100644 --- a/linkerd/tls/tests/tls_accept.rs +++ b/linkerd/tls/tests/tls_accept.rs @@ -8,17 +8,16 @@ use futures::prelude::*; use linkerd_conditional::Conditional; use linkerd_error::Infallible; -use linkerd_identity as id; use linkerd_io::{self as io, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use linkerd_proxy_transport::{ addrs::*, listen::{Addrs, Bind, BindTcp}, ConnectTcp, Keepalive, ListenAddr, }; -use linkerd_stack::{ExtractParam, InsertParam, NewService, Param}; +use linkerd_stack::{ExtractParam, InsertParam, NewService, Param, Service}; use linkerd_tls as tls; -use std::{future::Future, time::Duration}; -use std::{net::SocketAddr, sync::mpsc}; +use linkerd_tls_rustls as rustls; +use std::{future::Future, net::SocketAddr, sync::mpsc, task, time::Duration}; use tokio::net::TcpStream; use tower::{ layer::Layer, @@ -26,12 +25,15 @@ use tower::{ }; use tracing::instrument::Instrument; -type ServerConn = ((tls::ConditionalServerTls, T), tls::server::Io); +type ServerConn = ( + (tls::ConditionalServerTls, T), + io::EitherIo>, tls::server::DetectIo>, +); #[tokio::test(flavor = "current_thread")] async fn plaintext() { - let server_tls = id::test_util::FOO_NS1.validate().unwrap(); - let client_tls = id::test_util::BAR_NS1.validate().unwrap(); + let server_tls = rustls::test_util::FOO_NS1.validate().unwrap(); + let client_tls = rustls::test_util::BAR_NS1.validate().unwrap(); let (client_result, server_result) = run_test( client_tls, Conditional::None(tls::NoClientTls::NotProvidedByServiceDiscovery), @@ -56,8 +58,8 @@ async fn plaintext() { #[tokio::test(flavor = "current_thread")] async fn proxy_to_proxy_tls_works() { - let server_tls = id::test_util::FOO_NS1.validate().unwrap(); - let client_tls = id::test_util::BAR_NS1.validate().unwrap(); + let server_tls = rustls::test_util::FOO_NS1.validate().unwrap(); + let client_tls = rustls::test_util::BAR_NS1.validate().unwrap(); let server_id = tls::ServerId(server_tls.name().clone()); let (client_result, server_result) = run_test( client_tls.clone(), @@ -87,14 +89,14 @@ async fn proxy_to_proxy_tls_works() { #[tokio::test(flavor = "current_thread")] async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match() { - let server_tls = id::test_util::FOO_NS1.validate().unwrap(); + let server_tls = rustls::test_util::FOO_NS1.validate().unwrap(); // Misuse the client's identity instead of the server's identity. Any // identity other than `server_tls.server_identity` would work. - let client_tls = id::test_util::BAR_NS1 + let client_tls = rustls::test_util::BAR_NS1 .validate() .expect("valid client cert"); - let sni = id::test_util::BAR_NS1.crt().name().clone(); + let sni = rustls::test_util::BAR_NS1.crt().name().clone(); let (client_result, server_result) = run_test( client_tls, @@ -127,17 +129,19 @@ struct Transported { #[derive(Clone)] struct ServerParams { - identity: id::CrtKey, + identity: rustls::CrtKey, } +type ClientIo = io::EitherIo, rustls::ClientIo>>; + /// Runs a test for a single TCP connection. `client` processes the connection /// on the client side and `server` processes the connection on the server /// side. async fn run_test( - client_tls: id::CrtKey, + client_tls: rustls::CrtKey, client_server_id: Conditional, client: C, - server_id: id::CrtKey, + server_id: rustls::CrtKey, server: S, ) -> ( Transported, @@ -145,7 +149,7 @@ async fn run_test( ) where // Client - C: FnOnce(tls::client::Io>) -> CF + Clone + Send + 'static, + C: FnOnce(ClientIo) -> CF + Clone + Send + 'static, CF: Future> + Send + 'static, CR: Send + 'static, // Server @@ -310,7 +314,7 @@ struct Server; struct Target(SocketAddr, tls::ConditionalClientTls); #[derive(Clone)] -struct Tls(id::CrtKey); +struct Tls(rustls::CrtKey); // === impl Target === @@ -328,15 +332,30 @@ impl Param for Target { // === impl Tls === -impl Param for Tls { - fn param(&self) -> tls::client::Config { - self.0.client_config() +impl NewService for Tls { + type Service = rustls::Connect; + + fn new_service(&self, target: tls::ClientTls) -> Self::Service { + rustls::Connect::new(target, self.0.client_config()) } } -impl Param for Tls { - fn param(&self) -> tls::server::Config { - self.0.server_config() +impl Service for Tls +where + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, +{ + type Response = (tls::ServerTls, rustls::ServerIo); + type Error = io::Error; + type Future = rustls::TerminateFuture; + + #[inline] + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> task::Poll> { + task::Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, io: I) -> Self::Future { + rustls::terminate(self.0.server_config(), io) } }