diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 5656aa95b5881..d79bbf7469329 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,11 +1,13 @@ use std::collections::HashMap; use std::env; +use std::iter::FromIterator; use std::net::SocketAddr; use std::path::PathBuf; use std::str::FromStr; use std::time::Duration; use http; +use indexmap::IndexSet; use transport::{Host, HostAndPort, HostAndPortError}; use convert::TryFrom; @@ -38,6 +40,10 @@ pub struct Config { /// The maximum amount of time to wait for a connection to the private peer. pub private_connect_timeout: Duration, + pub inbound_ports_disable_protocol_detection: IndexSet, + + pub outbound_ports_disable_protocol_detection: IndexSet, + /// The path to "/etc/resolv.conf" pub resolv_conf_path: PathBuf, @@ -136,6 +142,11 @@ const ENV_PRIVATE_CONNECT_TIMEOUT: &str = "CONDUIT_PROXY_PRIVATE_CONNECT_TIMEOUT const ENV_PUBLIC_CONNECT_TIMEOUT: &str = "CONDUIT_PROXY_PUBLIC_CONNECT_TIMEOUT"; pub const ENV_BIND_TIMEOUT: &str = "CONDUIT_PROXY_BIND_TIMEOUT"; +// These *disable* our protocol detection for connections whose SO_ORIGINAL_DST +// has a port in the provided list. +pub const ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION: &str = "CONDUIT_PROXY_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION"; +pub const ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION: &str = "CONDUIT_PROXY_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION"; + const ENV_NODE_NAME: &str = "CONDUIT_PROXY_NODE_NAME"; const ENV_POD_NAME: &str = "CONDUIT_PROXY_POD_NAME"; pub const ENV_POD_NAMESPACE: &str = "CONDUIT_PROXY_POD_NAMESPACE"; @@ -156,6 +167,14 @@ const DEFAULT_PUBLIC_CONNECT_TIMEOUT_MS: u64 = 300; const DEFAULT_BIND_TIMEOUT_MS: u64 = 10_000; // ten seconds, as in Linkerd. const DEFAULT_RESOLV_CONF: &str = "/etc/resolv.conf"; +// By default, we keep a list of known assigned ports of server-first protocols. +// +// https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.txt +const DEFAULT_PORTS_DISABLE_PROTOCOL_DETECTION: &[u16] = &[ + 25, // SMTP + 3306, // MySQL +]; + // ===== impl Config ===== impl<'a> TryFrom<&'a Strings> for Config { @@ -172,6 +191,8 @@ impl<'a> TryFrom<&'a Strings> for Config { let private_forward = parse(strings, ENV_PRIVATE_FORWARD, str::parse); let public_connect_timeout = parse(strings, ENV_PUBLIC_CONNECT_TIMEOUT, parse_number); let private_connect_timeout = parse(strings, ENV_PRIVATE_CONNECT_TIMEOUT, parse_number); + let inbound_disable_ports = parse(strings, ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, parse_port_set); + let outbound_disable_ports = parse(strings, ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, parse_port_set); let bind_timeout = parse(strings, ENV_BIND_TIMEOUT, parse_number); let resolv_conf_path = strings.get(ENV_RESOLV_CONF); let event_buffer_capacity = parse(strings, ENV_EVENT_BUFFER_CAPACITY, parse_number); @@ -224,6 +245,10 @@ impl<'a> TryFrom<&'a Strings> for Config { private_connect_timeout: Duration::from_millis(private_connect_timeout? .unwrap_or(DEFAULT_PRIVATE_CONNECT_TIMEOUT_MS)), + inbound_ports_disable_protocol_detection: inbound_disable_ports? + .unwrap_or_else(|| default_disable_ports_protocol_detection()), + outbound_ports_disable_protocol_detection: outbound_disable_ports? + .unwrap_or_else(|| default_disable_ports_protocol_detection()), resolv_conf_path: resolv_conf_path? .unwrap_or(DEFAULT_RESOLV_CONF.into()) .into(), @@ -244,10 +269,8 @@ impl<'a> TryFrom<&'a Strings> for Config { } } -impl Config { - pub fn default_destination_namespace(&self) -> &str { - &self.pod_namespace - } +fn default_disable_ports_protocol_detection() -> IndexSet { + IndexSet::from_iter(DEFAULT_PORTS_DISABLE_PROTOCOL_DETECTION.iter().cloned()) } // ===== impl Addr ===== @@ -330,6 +353,14 @@ fn parse_url(s: &str) -> Result { .map_err(|e| ParseError::UrlError(UrlError::AuthorityError(e))) } +fn parse_port_set(s: &str) -> Result, ParseError> { + let mut set = IndexSet::new(); + for num in s.split(',') { + set.insert(parse_number::(num)?); + } + Ok(set) +} + fn parse(strings: &Strings, name: &str, parse: Parse) -> Result, Error> where Parse: FnOnce(&str) -> Result { match strings.get(name)? { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index e4aa2fb25ff82..9320af0635d54 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -51,6 +51,7 @@ use std::sync::Arc; use std::thread; use std::time::Duration; +use indexmap::IndexSet; use tokio_core::reactor::{Core, Handle}; use tower::NewService; use tower_fn::*; @@ -186,6 +187,14 @@ where "serving Prometheus metrics on {:?}", metrics_listener.local_addr(), ); + info!( + "protocol detection disabled for inbound ports {:?}", + config.inbound_ports_disable_protocol_detection, + ); + info!( + "protocol detection disabled for outbound ports {:?}", + config.outbound_ports_disable_protocol_detection, + ); let (sensors, telemetry) = telemetry::new( &process_ctx, @@ -215,6 +224,7 @@ where inbound_listener, Inbound::new(default_addr, bind), config.private_connect_timeout, + config.inbound_ports_disable_protocol_detection, ctx, sensors.clone(), get_original_dst.clone(), @@ -234,7 +244,7 @@ where let outgoing = Outbound::new( bind, control, - config.default_destination_namespace().to_owned(), + config.pod_namespace.to_owned(), config.bind_timeout, ); @@ -242,6 +252,7 @@ where outbound_listener, outgoing, config.public_connect_timeout, + config.outbound_ports_disable_protocol_detection, ctx, sensors, get_original_dst, @@ -254,6 +265,7 @@ where let (_tx, controller_shutdown_signal) = futures::sync::oneshot::channel::<()>(); { + let report_timeout = config.report_timeout; thread::Builder::new() .name("controller-client".into()) .spawn(move || { @@ -282,7 +294,7 @@ where telemetry, control_host_and_port, dns_config, - config.report_timeout, + report_timeout, &executor ); @@ -312,6 +324,7 @@ fn serve( bound_port: BoundPort, recognize: R, tcp_connect_timeout: Duration, + disable_protocol_detection_ports: IndexSet, proxy_ctx: Arc, sensors: telemetry::Sensors, get_orig_dst: G, @@ -362,6 +375,7 @@ where get_orig_dst, stack, tcp_connect_timeout, + disable_protocol_detection_ports, executor.clone(), ); diff --git a/proxy/src/transparency/server.rs b/proxy/src/transparency/server.rs index 4b24b518ecade..b3e3edd772d9c 100644 --- a/proxy/src/transparency/server.rs +++ b/proxy/src/transparency/server.rs @@ -6,6 +6,7 @@ use std::time::{Duration, Instant}; use futures::Future; use http; use hyper; +use indexmap::IndexSet; use tokio_core::reactor::Handle; use tower::NewService; use tower_h2; @@ -30,6 +31,7 @@ where S: NewService>, S::Future: 'static, { + disable_protocol_detection_ports: IndexSet, executor: Handle, get_orig_dst: G, h1: hyper::server::Http, @@ -60,11 +62,13 @@ where get_orig_dst: G, stack: S, tcp_connect_timeout: Duration, + disable_protocol_detection_ports: IndexSet, executor: Handle, ) -> Self { let recv_body_svc = HttpBodyNewSvc::new(stack.clone()); let tcp = tcp::Proxy::new(tcp_connect_timeout, sensors.clone(), &executor); Server { + disable_protocol_detection_ports, executor: executor.clone(), get_orig_dst, h1: hyper::server::Http::new(), @@ -89,9 +93,34 @@ where // create Server context let orig_dst = connection.original_dst_addr(&self.get_orig_dst); let local_addr = connection.local_addr().unwrap_or(self.listen_addr); - let proxy_ctx = self.proxy_ctx.clone(); + + // We are using the port from the connection's SO_ORIGINAL_DST to + // determine whether to skip protocol detection, not any port that + // would be found after doing discovery. + let disable_protocol_detection = orig_dst + .map(|addr| { + self.disable_protocol_detection_ports.contains(&addr.port()) + }) + .unwrap_or(false); + + if disable_protocol_detection { + trace!("protocol detection disabled for {:?}", orig_dst); + let fut = tcp_serve( + &self.tcp, + connection, + &self.sensors, + opened_at, + &self.proxy_ctx, + LocalAddr(&local_addr), + RemoteAddr(&remote_addr), + OrigDst(&orig_dst), + ); + self.executor.spawn(fut); + return; + } // try to sniff protocol + let proxy_ctx = self.proxy_ctx.clone(); let sniff = [0u8; 32]; let sensors = self.sensors.clone(); let h1 = self.h1.clone(); @@ -138,19 +167,16 @@ where } } else { trace!("transparency did not detect protocol, treating as TCP"); - - let srv_ctx = ServerCtx::new( + tcp_serve( + &tcp, + connection, + &sensors, + opened_at, &proxy_ctx, - &local_addr, - &remote_addr, - &orig_dst, - common::Protocol::Tcp, - ); - - // record telemetry - let tcp_in = sensors.accept(connection, opened_at, &srv_ctx); - - tcp.serve(tcp_in, srv_ctx) + LocalAddr(&local_addr), + RemoteAddr(&remote_addr), + OrigDst(&orig_dst), + ) } }); @@ -158,3 +184,34 @@ where } } +// These newtypes act as a form of keyword arguments. +// +// It should be easier to notice when wrapping `LocalAddr(remote_addr)` at +// the call site, then simply passing multiple socket addr arguments. +struct LocalAddr<'a>(&'a SocketAddr); +struct RemoteAddr<'a>(&'a SocketAddr); +struct OrigDst<'a>(&'a Option); + +fn tcp_serve( + tcp: &tcp::Proxy, + connection: Connection, + sensors: &Sensors, + opened_at: Instant, + proxy_ctx: &Arc, + local_addr: LocalAddr, + remote_addr: RemoteAddr, + orig_dst: OrigDst, +) -> Box> { + let srv_ctx = ServerCtx::new( + proxy_ctx, + local_addr.0, + remote_addr.0, + orig_dst.0, + common::Protocol::Tcp, + ); + + // record telemetry + let tcp_in = sensors.accept(connection, opened_at, &srv_ctx); + + tcp.serve(tcp_in, srv_ctx) +} diff --git a/proxy/tests/support/proxy.rs b/proxy/tests/support/proxy.rs index 55dbbb6f671d5..960b3d2377aa0 100644 --- a/proxy/tests/support/proxy.rs +++ b/proxy/tests/support/proxy.rs @@ -15,6 +15,8 @@ pub struct Proxy { outbound: Option, metrics_flush_interval: Option, + inbound_disable_ports_protocol_detection: Option>, + outbound_disable_ports_protocol_detection: Option>, } #[derive(Debug)] @@ -38,6 +40,8 @@ impl Proxy { outbound: None, metrics_flush_interval: None, + inbound_disable_ports_protocol_detection: None, + outbound_disable_ports_protocol_detection: None, } } @@ -61,6 +65,16 @@ impl Proxy { self } + pub fn disable_inbound_ports_protocol_detection(mut self, ports: Vec) -> Self { + self.inbound_disable_ports_protocol_detection = Some(ports); + self + } + + pub fn disable_outbound_ports_protocol_detection(mut self, ports: Vec) -> Self { + self.outbound_disable_ports_protocol_detection = Some(ports); + self + } + pub fn run(self) -> Listening { self.run_with_test_env(config::TestEnv::new()) } @@ -121,6 +135,28 @@ fn run(proxy: Proxy, mut env: config::TestEnv) -> Listening { env.put(config::ENV_METRICS_LISTENER, "tcp://127.0.0.1:0".to_owned()); env.put(config::ENV_POD_NAMESPACE, "test".to_owned()); + if let Some(ports) = proxy.inbound_disable_ports_protocol_detection { + let ports = ports.into_iter() + .map(|p| p.to_string()) + .collect::>() + .join(","); + env.put( + config::ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, + ports + ); + } + + if let Some(ports) = proxy.outbound_disable_ports_protocol_detection { + let ports = ports.into_iter() + .map(|p| p.to_string()) + .collect::>() + .join(","); + env.put( + config::ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, + ports + ); + } + let mut config = config::Config::try_from(&env).unwrap(); // TODO: We currently can't use `config::ENV_METRICS_FLUSH_INTERVAL_SECS` diff --git a/proxy/tests/transparency.rs b/proxy/tests/transparency.rs index 689eac316bc9a..e7d7e9f0b6fac 100644 --- a/proxy/tests/transparency.rs +++ b/proxy/tests/transparency.rs @@ -228,6 +228,46 @@ fn inbound_tcp() { assert_eq!(tcp_client.read(), msg2.as_bytes()); } +#[test] +fn tcp_server_first() { + use std::sync::mpsc; + + let _ = env_logger::try_init(); + + let msg1 = "custom tcp server starts"; + let msg2 = "custom tcp client second"; + + let (tx, rx) = mpsc::channel(); + + let srv = server::tcp() + .accept_fut(move |sock| { + tokio_io::io::write_all(sock, msg1.as_bytes()) + .and_then(move |(sock, _)| { + tokio_io::io::read(sock, vec![0; 512]) + }) + .map(move |(_sock, vec, n)| { + assert_eq!(&vec[..n], msg2.as_bytes()); + tx.send(()).unwrap(); + }) + .map_err(|e| panic!("tcp server error: {}", e)) + }) + .run(); + let ctrl = controller::new().run(); + let proxy = proxy::new() + .controller(ctrl) + .disable_inbound_ports_protocol_detection(vec![srv.addr.port()]) + .inbound(srv) + .run(); + + let client = client::tcp(proxy.inbound); + + let tcp_client = client.connect(); + + assert_eq!(tcp_client.read(), msg1.as_bytes()); + tcp_client.write(msg2); + rx.recv_timeout(Duration::from_secs(5)).unwrap(); +} + #[test] fn tcp_with_no_orig_dst() { let _ = env_logger::try_init();