Skip to content

Commit

Permalink
proxy: allow disable protocol detection on specific ports (#648)
Browse files Browse the repository at this point in the history
- Adds environment variables to configure a set of ports that, when an
  incoming connection has an SO_ORIGINAL_DST with a port matching, will
  disable protocol detection for that connection and immediately start a
  TCP proxy.
- Adds a default list of well known ports: SMTP and MySQL.

Closes #339
  • Loading branch information
seanmonstar authored Apr 2, 2018
1 parent 97546e0 commit 47f9665
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 19 deletions.
39 changes: 35 additions & 4 deletions proxy/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::collections::HashMap;
use std::env;
use std::iter::FromIterator;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use http;
use indexmap::IndexSet;

use transport::{Host, HostAndPort, HostAndPortError};
use convert::TryFrom;
Expand Down Expand Up @@ -38,6 +40,10 @@ pub struct Config {
/// The maximum amount of time to wait for a connection to the private peer.
pub private_connect_timeout: Duration,

pub inbound_ports_disable_protocol_detection: IndexSet<u16>,

pub outbound_ports_disable_protocol_detection: IndexSet<u16>,

/// The path to "/etc/resolv.conf"
pub resolv_conf_path: PathBuf,

Expand Down Expand Up @@ -136,6 +142,11 @@ const ENV_PRIVATE_CONNECT_TIMEOUT: &str = "CONDUIT_PROXY_PRIVATE_CONNECT_TIMEOUT
const ENV_PUBLIC_CONNECT_TIMEOUT: &str = "CONDUIT_PROXY_PUBLIC_CONNECT_TIMEOUT";
pub const ENV_BIND_TIMEOUT: &str = "CONDUIT_PROXY_BIND_TIMEOUT";

// These *disable* our protocol detection for connections whose SO_ORIGINAL_DST
// has a port in the provided list.
pub const ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION: &str = "CONDUIT_PROXY_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION";
pub const ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION: &str = "CONDUIT_PROXY_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION";

const ENV_NODE_NAME: &str = "CONDUIT_PROXY_NODE_NAME";
const ENV_POD_NAME: &str = "CONDUIT_PROXY_POD_NAME";
pub const ENV_POD_NAMESPACE: &str = "CONDUIT_PROXY_POD_NAMESPACE";
Expand All @@ -156,6 +167,14 @@ const DEFAULT_PUBLIC_CONNECT_TIMEOUT_MS: u64 = 300;
const DEFAULT_BIND_TIMEOUT_MS: u64 = 10_000; // ten seconds, as in Linkerd.
const DEFAULT_RESOLV_CONF: &str = "/etc/resolv.conf";

// By default, we keep a list of known assigned ports of server-first protocols.
//
// https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.txt
const DEFAULT_PORTS_DISABLE_PROTOCOL_DETECTION: &[u16] = &[
25, // SMTP
3306, // MySQL
];

// ===== impl Config =====

impl<'a> TryFrom<&'a Strings> for Config {
Expand All @@ -172,6 +191,8 @@ impl<'a> TryFrom<&'a Strings> for Config {
let private_forward = parse(strings, ENV_PRIVATE_FORWARD, str::parse);
let public_connect_timeout = parse(strings, ENV_PUBLIC_CONNECT_TIMEOUT, parse_number);
let private_connect_timeout = parse(strings, ENV_PRIVATE_CONNECT_TIMEOUT, parse_number);
let inbound_disable_ports = parse(strings, ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, parse_port_set);
let outbound_disable_ports = parse(strings, ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION, parse_port_set);
let bind_timeout = parse(strings, ENV_BIND_TIMEOUT, parse_number);
let resolv_conf_path = strings.get(ENV_RESOLV_CONF);
let event_buffer_capacity = parse(strings, ENV_EVENT_BUFFER_CAPACITY, parse_number);
Expand Down Expand Up @@ -224,6 +245,10 @@ impl<'a> TryFrom<&'a Strings> for Config {
private_connect_timeout:
Duration::from_millis(private_connect_timeout?
.unwrap_or(DEFAULT_PRIVATE_CONNECT_TIMEOUT_MS)),
inbound_ports_disable_protocol_detection: inbound_disable_ports?
.unwrap_or_else(|| default_disable_ports_protocol_detection()),
outbound_ports_disable_protocol_detection: outbound_disable_ports?
.unwrap_or_else(|| default_disable_ports_protocol_detection()),
resolv_conf_path: resolv_conf_path?
.unwrap_or(DEFAULT_RESOLV_CONF.into())
.into(),
Expand All @@ -244,10 +269,8 @@ impl<'a> TryFrom<&'a Strings> for Config {
}
}

impl Config {
pub fn default_destination_namespace(&self) -> &str {
&self.pod_namespace
}
fn default_disable_ports_protocol_detection() -> IndexSet<u16> {
IndexSet::from_iter(DEFAULT_PORTS_DISABLE_PROTOCOL_DETECTION.iter().cloned())
}

// ===== impl Addr =====
Expand Down Expand Up @@ -330,6 +353,14 @@ fn parse_url(s: &str) -> Result<HostAndPort, ParseError> {
.map_err(|e| ParseError::UrlError(UrlError::AuthorityError(e)))
}

fn parse_port_set(s: &str) -> Result<IndexSet<u16>, ParseError> {
let mut set = IndexSet::new();
for num in s.split(',') {
set.insert(parse_number::<u16>(num)?);
}
Ok(set)
}

fn parse<T, Parse>(strings: &Strings, name: &str, parse: Parse) -> Result<Option<T>, Error>
where Parse: FnOnce(&str) -> Result<T, ParseError> {
match strings.get(name)? {
Expand Down
18 changes: 16 additions & 2 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use std::sync::Arc;
use std::thread;
use std::time::Duration;

use indexmap::IndexSet;
use tokio_core::reactor::{Core, Handle};
use tower::NewService;
use tower_fn::*;
Expand Down Expand Up @@ -186,6 +187,14 @@ where
"serving Prometheus metrics on {:?}",
metrics_listener.local_addr(),
);
info!(
"protocol detection disabled for inbound ports {:?}",
config.inbound_ports_disable_protocol_detection,
);
info!(
"protocol detection disabled for outbound ports {:?}",
config.outbound_ports_disable_protocol_detection,
);

let (sensors, telemetry) = telemetry::new(
&process_ctx,
Expand Down Expand Up @@ -215,6 +224,7 @@ where
inbound_listener,
Inbound::new(default_addr, bind),
config.private_connect_timeout,
config.inbound_ports_disable_protocol_detection,
ctx,
sensors.clone(),
get_original_dst.clone(),
Expand All @@ -234,14 +244,15 @@ where
let outgoing = Outbound::new(
bind,
control,
config.default_destination_namespace().to_owned(),
config.pod_namespace.to_owned(),
config.bind_timeout,
);

let fut = serve(
outbound_listener,
outgoing,
config.public_connect_timeout,
config.outbound_ports_disable_protocol_detection,
ctx,
sensors,
get_original_dst,
Expand All @@ -254,6 +265,7 @@ where

let (_tx, controller_shutdown_signal) = futures::sync::oneshot::channel::<()>();
{
let report_timeout = config.report_timeout;
thread::Builder::new()
.name("controller-client".into())
.spawn(move || {
Expand Down Expand Up @@ -282,7 +294,7 @@ where
telemetry,
control_host_and_port,
dns_config,
config.report_timeout,
report_timeout,
&executor
);

Expand Down Expand Up @@ -312,6 +324,7 @@ fn serve<R, B, E, F, G>(
bound_port: BoundPort,
recognize: R,
tcp_connect_timeout: Duration,
disable_protocol_detection_ports: IndexSet<u16>,
proxy_ctx: Arc<ctx::Proxy>,
sensors: telemetry::Sensors,
get_orig_dst: G,
Expand Down Expand Up @@ -362,6 +375,7 @@ where
get_orig_dst,
stack,
tcp_connect_timeout,
disable_protocol_detection_ports,
executor.clone(),
);

Expand Down
83 changes: 70 additions & 13 deletions proxy/src/transparency/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::{Duration, Instant};
use futures::Future;
use http;
use hyper;
use indexmap::IndexSet;
use tokio_core::reactor::Handle;
use tower::NewService;
use tower_h2;
Expand All @@ -30,6 +31,7 @@ where
S: NewService<Request=http::Request<HttpBody>>,
S::Future: 'static,
{
disable_protocol_detection_ports: IndexSet<u16>,
executor: Handle,
get_orig_dst: G,
h1: hyper::server::Http,
Expand Down Expand Up @@ -60,11 +62,13 @@ where
get_orig_dst: G,
stack: S,
tcp_connect_timeout: Duration,
disable_protocol_detection_ports: IndexSet<u16>,
executor: Handle,
) -> Self {
let recv_body_svc = HttpBodyNewSvc::new(stack.clone());
let tcp = tcp::Proxy::new(tcp_connect_timeout, sensors.clone(), &executor);
Server {
disable_protocol_detection_ports,
executor: executor.clone(),
get_orig_dst,
h1: hyper::server::Http::new(),
Expand All @@ -89,9 +93,34 @@ where
// create Server context
let orig_dst = connection.original_dst_addr(&self.get_orig_dst);
let local_addr = connection.local_addr().unwrap_or(self.listen_addr);
let proxy_ctx = self.proxy_ctx.clone();

// We are using the port from the connection's SO_ORIGINAL_DST to
// determine whether to skip protocol detection, not any port that
// would be found after doing discovery.
let disable_protocol_detection = orig_dst
.map(|addr| {
self.disable_protocol_detection_ports.contains(&addr.port())
})
.unwrap_or(false);

if disable_protocol_detection {
trace!("protocol detection disabled for {:?}", orig_dst);
let fut = tcp_serve(
&self.tcp,
connection,
&self.sensors,
opened_at,
&self.proxy_ctx,
LocalAddr(&local_addr),
RemoteAddr(&remote_addr),
OrigDst(&orig_dst),
);
self.executor.spawn(fut);
return;
}

// try to sniff protocol
let proxy_ctx = self.proxy_ctx.clone();
let sniff = [0u8; 32];
let sensors = self.sensors.clone();
let h1 = self.h1.clone();
Expand Down Expand Up @@ -138,23 +167,51 @@ where
}
} else {
trace!("transparency did not detect protocol, treating as TCP");

let srv_ctx = ServerCtx::new(
tcp_serve(
&tcp,
connection,
&sensors,
opened_at,
&proxy_ctx,
&local_addr,
&remote_addr,
&orig_dst,
common::Protocol::Tcp,
);

// record telemetry
let tcp_in = sensors.accept(connection, opened_at, &srv_ctx);

tcp.serve(tcp_in, srv_ctx)
LocalAddr(&local_addr),
RemoteAddr(&remote_addr),
OrigDst(&orig_dst),
)
}
});

self.executor.spawn(fut);
}
}

// These newtypes act as a form of keyword arguments.
//
// It should be easier to notice when wrapping `LocalAddr(remote_addr)` at
// the call site, then simply passing multiple socket addr arguments.
struct LocalAddr<'a>(&'a SocketAddr);
struct RemoteAddr<'a>(&'a SocketAddr);
struct OrigDst<'a>(&'a Option<SocketAddr>);

fn tcp_serve(
tcp: &tcp::Proxy,
connection: Connection,
sensors: &Sensors,
opened_at: Instant,
proxy_ctx: &Arc<ProxyCtx>,
local_addr: LocalAddr,
remote_addr: RemoteAddr,
orig_dst: OrigDst,
) -> Box<Future<Item=(), Error=()>> {
let srv_ctx = ServerCtx::new(
proxy_ctx,
local_addr.0,
remote_addr.0,
orig_dst.0,
common::Protocol::Tcp,
);

// record telemetry
let tcp_in = sensors.accept(connection, opened_at, &srv_ctx);

tcp.serve(tcp_in, srv_ctx)
}
36 changes: 36 additions & 0 deletions proxy/tests/support/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub struct Proxy {
outbound: Option<server::Listening>,

metrics_flush_interval: Option<Duration>,
inbound_disable_ports_protocol_detection: Option<Vec<u16>>,
outbound_disable_ports_protocol_detection: Option<Vec<u16>>,
}

#[derive(Debug)]
Expand All @@ -38,6 +40,8 @@ impl Proxy {
outbound: None,

metrics_flush_interval: None,
inbound_disable_ports_protocol_detection: None,
outbound_disable_ports_protocol_detection: None,
}
}

Expand All @@ -61,6 +65,16 @@ impl Proxy {
self
}

pub fn disable_inbound_ports_protocol_detection(mut self, ports: Vec<u16>) -> Self {
self.inbound_disable_ports_protocol_detection = Some(ports);
self
}

pub fn disable_outbound_ports_protocol_detection(mut self, ports: Vec<u16>) -> Self {
self.outbound_disable_ports_protocol_detection = Some(ports);
self
}

pub fn run(self) -> Listening {
self.run_with_test_env(config::TestEnv::new())
}
Expand Down Expand Up @@ -121,6 +135,28 @@ fn run(proxy: Proxy, mut env: config::TestEnv) -> Listening {
env.put(config::ENV_METRICS_LISTENER, "tcp://127.0.0.1:0".to_owned());
env.put(config::ENV_POD_NAMESPACE, "test".to_owned());

if let Some(ports) = proxy.inbound_disable_ports_protocol_detection {
let ports = ports.into_iter()
.map(|p| p.to_string())
.collect::<Vec<_>>()
.join(",");
env.put(
config::ENV_INBOUND_PORTS_DISABLE_PROTOCOL_DETECTION,
ports
);
}

if let Some(ports) = proxy.outbound_disable_ports_protocol_detection {
let ports = ports.into_iter()
.map(|p| p.to_string())
.collect::<Vec<_>>()
.join(",");
env.put(
config::ENV_OUTBOUND_PORTS_DISABLE_PROTOCOL_DETECTION,
ports
);
}

let mut config = config::Config::try_from(&env).unwrap();

// TODO: We currently can't use `config::ENV_METRICS_FLUSH_INTERVAL_SECS`
Expand Down
Loading

0 comments on commit 47f9665

Please sign in to comment.