Skip to content

Commit

Permalink
Merge pull request #41 from launchbadge/ab/tls
Browse files Browse the repository at this point in the history
implement TLS for Postgres and MySQL
  • Loading branch information
mehcode authored Jan 14, 2020
2 parents 6c8fd94 + 114aaa5 commit 684068a
Show file tree
Hide file tree
Showing 19 changed files with 800 additions and 25 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/mysql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ jobs:
# will assign a random free host port
- 3306/tcp
# needed because the container does not provide a healthcheck
options: --health-cmd "mysqladmin ping --silent" --health-interval 30s --health-timeout 30s --health-retries 10
options: >-
--health-cmd "mysqladmin ping --silent" --health-interval 30s --health-timeout 30s
--health-retries 10 -v /data/mysql:/var/lib/mysql
steps:
- uses: actions/checkout@v1
Expand All @@ -48,9 +51,11 @@ jobs:

# -----------------------------------------------------

- run: cargo test -p sqlx --no-default-features --features 'mysql macros chrono'
- run: cargo test -p sqlx --no-default-features --features 'mysql macros chrono tls'
env:
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx
# pass the path to the CA that the MySQL service generated
# Github Actions' YML parser doesn't handle multiline strings correctly
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx?ssl-mode=VERIFY_CA&ssl-ca=%2Fdata%2Fmysql%2Fca.pem

# Rust ------------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/postgres.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ jobs:

# -----------------------------------------------------

# Check that we build with TLS support (TODO: we need a postgres image with SSL certs to test)
- run: cargo check -p sqlx-core --no-default-features --features 'postgres macros uuid chrono tls'

- run: cargo test -p sqlx --no-default-features --features 'postgres macros uuid chrono'
env:
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres
Expand Down
181 changes: 181 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ all-features = true
[features]
default = [ "macros" ]
macros = [ "sqlx-macros", "proc-macro-hack" ]
tls = ["sqlx-core/tls"]

# database
postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ]
Expand All @@ -48,6 +49,7 @@ hex = "0.4.0"
[dev-dependencies]
anyhow = "1.0.26"
futures = "0.3.1"
env_logger = "0.7"
async-std = { version = "1.4.0", features = [ "attributes" ] }
dotenv = "0.15.0"

Expand Down
2 changes: 2 additions & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ default = []
unstable = []
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
tls = ["async-native-tls"]

[dependencies]
async-native-tls = { version = "0.3", optional = true }
async-std = "1.4.0"
async-stream = { version = "0.2.0", default-features = false }
base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] }
Expand Down
31 changes: 31 additions & 0 deletions sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub enum Error {
/// [Pool::close] was called while we were waiting in [Pool::acquire].
PoolClosed,

/// An error occurred during a TLS upgrade.
TlsUpgrade(Box<dyn StdError + Send + Sync>),

Decode(DecodeError),

// TODO: Remove and replace with `#[non_exhaustive]` when possible
Expand All @@ -62,6 +65,8 @@ impl StdError for Error {

Error::Decode(DecodeError::Other(error)) => Some(&**error),

Error::TlsUpgrade(error) => Some(&**error),

_ => None,
}
}
Expand Down Expand Up @@ -100,6 +105,8 @@ impl Display for Error {

Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),

Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),

Error::__Nonexhaustive => unreachable!(),
}
}
Expand Down Expand Up @@ -140,6 +147,21 @@ impl From<ProtocolError<'_>> for Error {
}
}

#[cfg(feature = "tls")]
impl From<async_native_tls::Error> for Error {
#[inline]
fn from(err: async_native_tls::Error) -> Self {
Error::TlsUpgrade(err.into())
}
}

impl From<TlsError<'_>> for Error {
#[inline]
fn from(err: TlsError<'_>) -> Self {
Error::TlsUpgrade(err.args.to_string().into())
}
}

impl<T> From<T> for Error
where
T: 'static + DatabaseError,
Expand Down Expand Up @@ -189,6 +211,15 @@ macro_rules! protocol_err (
}
);

pub(crate) struct TlsError<'a> {
pub args: fmt::Arguments<'a>,
}

#[allow(unused_macros)]
macro_rules! tls_err {
($($args:tt)*) => { crate::error::TlsError { args: format_args!($($args)*)} };
}

#[allow(unused_macros)]
macro_rules! impl_fmt_error {
($err:ty) => {
Expand Down
21 changes: 21 additions & 0 deletions sqlx-core/src/io/buf_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use async_std::io::{
Read, Write,
};
use std::io;
use std::ops::{Deref, DerefMut};

const RBUF_SIZE: usize = 8 * 1024;

Expand Down Expand Up @@ -51,6 +52,12 @@ where
Ok(())
}

pub fn clear_bufs(&mut self) {
self.rbuf_rindex = 0;
self.rbuf_windex = 0;
self.wbuf.clear();
}

#[inline]
pub fn consume(&mut self, cnt: usize) {
self.rbuf_rindex += cnt;
Expand Down Expand Up @@ -118,6 +125,20 @@ where
}
}

impl<S> Deref for BufStream<S> {
type Target = S;

fn deref(&self) -> &Self::Target {
&self.stream
}
}

impl<S> DerefMut for BufStream<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}

// TODO: Find a nicer way to do this
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
#[allow(unused)]
Expand Down
3 changes: 3 additions & 0 deletions sqlx-core/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ mod buf;
mod buf_mut;
mod byte_str;

mod tls;

pub use self::{
buf::{Buf, ToBuf},
buf_mut::BufMut,
buf_stream::BufStream,
byte_str::ByteStr,
tls::MaybeTlsStream,
};

#[cfg(test)]
Expand Down
126 changes: 126 additions & 0 deletions sqlx-core/src/io/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use std::io::{IoSlice, IoSliceMut};
use std::pin::Pin;
use std::task::{Context, Poll};

use async_std::io::{self, Read, Write};
use async_std::net::{Shutdown, TcpStream};

use crate::url::Url;

use self::Inner::*;

pub struct MaybeTlsStream {
inner: Inner,
}

enum Inner {
NotTls(TcpStream),
#[cfg(feature = "tls")]
Tls(async_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "tls")]
Upgrading,
}

impl MaybeTlsStream {
pub async fn connect(url: &Url, default_port: u16) -> crate::Result<Self> {
let conn = TcpStream::connect((url.host(), url.port(default_port))).await?;
Ok(Self {
inner: Inner::NotTls(conn),
})
}

#[allow(dead_code)]
pub fn is_tls(&self) -> bool {
match self.inner {
Inner::NotTls(_) => false,
#[cfg(feature = "tls")]
Inner::Tls(_) => true,
#[cfg(feature = "tls")]
Inner::Upgrading => false,
}
}

#[cfg(feature = "tls")]
pub async fn upgrade(
&mut self,
url: &Url,
connector: async_native_tls::TlsConnector,
) -> crate::Result<()> {
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
NotTls(conn) => conn,
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
};

self.inner = Tls(connector.connect(url.host(), conn).await?);

Ok(())
}

pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match self.inner {
NotTls(ref conn) => conn.shutdown(how),
#[cfg(feature = "tls")]
Tls(ref conn) => conn.get_ref().shutdown(how),
#[cfg(feature = "tls")]
// connection already closed
Upgrading => Ok(()),
}
}
}

macro_rules! forward_pin (
($self:ident.$method:ident($($arg:ident),*)) => (
match &mut $self.inner {
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
}
)
);

impl Read for MaybeTlsStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read(cx, buf))
}

fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &mut [IoSliceMut],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read_vectored(cx, bufs))
}
}

impl Write for MaybeTlsStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write(cx, buf))
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_flush(cx))
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_close(cx))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[IoSlice],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write_vectored(cx, bufs))
}
}
Loading

0 comments on commit 684068a

Please sign in to comment.