Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade rustls to 0.20 #1505

Merged
merged 12 commits into from
Apr 14, 2022
56 changes: 32 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ _rt-actix = ["tokio-stream"]
_rt-async-std = []
_rt-tokio = ["tokio-stream"]
_tls-native-tls = []
_tls-rustls = ["rustls", "webpki", "webpki-roots"]
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]

# support offline/decoupled building (enables serialization of `Describe`)
offline = ["serde", "either/serde"]
Expand Down Expand Up @@ -144,7 +144,8 @@ parking_lot = "0.11.0"
rand = { version = "0.8.3", default-features = false, optional = true, features = ["std", "std_rng"] }
regex = { version = "1.3.9", optional = true }
rsa = { version = "0.5.0", optional = true }
rustls = { version = "0.19.0", features = ["dangerous_configuration"], optional = true }
rustls = { version = "0.20.1", features = ["dangerous_configuration"], optional = true }
rustls-pemfile = { version = "0.2.0", optional = true }
serde = { version = "1.0.106", features = ["derive", "rc"], optional = true }
serde_json = { version = "1.0.51", features = ["raw_value"], optional = true }
sha-1 = { version = "0.9.0", default-features = false, optional = true }
Expand All @@ -156,8 +157,7 @@ tokio-stream = { version = "0.1.2", features = ["fs"], optional = true }
smallvec = "1.4.0"
url = { version = "2.1.1", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true, features = ["std"] }
webpki = { version = "0.21.0", optional = true }
webpki-roots = { version = "0.21.0", optional = true }
webpki-roots = { version = "0.22.0", optional = true }
whoami = { version = "1.0.1", optional = true }
stringprep = "0.1.2"
bstr = { version = "0.2.14", default-features = false, features = ["std"], optional = true }
Expand Down
8 changes: 0 additions & 8 deletions sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,6 @@ impl From<sqlx_rt::native_tls::Error> for Error {
}
}

#[cfg(feature = "_tls-rustls")]
impl From<webpki::InvalidDNSNameError> for Error {
#[inline]
fn from(error: webpki::InvalidDNSNameError) -> Self {
Error::Tls(Box::new(error))
}
}

paolobarbolini marked this conversation as resolved.
Show resolved Hide resolved
// Format an error message as a `Protocol` error
macro_rules! err_protocol {
($expr:expr) => {
Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(dead_code)]

use std::convert::TryFrom;
use std::io;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
Expand Down Expand Up @@ -104,7 +105,7 @@ where
};

#[cfg(feature = "_tls-rustls")]
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?;

*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);

Expand Down
87 changes: 58 additions & 29 deletions sqlx-core/src/net/tls/rustls.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::net::CertificateInput;
use rustls::{
Certificate, ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError,
WebPKIVerifier,
client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier},
ClientConfig, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName,
};
use std::io::Cursor;
use std::sync::Arc;
use webpki::DNSNameRef;
use std::time::SystemTime;

use crate::error::Error;

Expand All @@ -14,32 +14,47 @@ pub async fn configure_tls_connector(
accept_invalid_hostnames: bool,
root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
let mut config = ClientConfig::new();
let config = ClientConfig::builder().with_safe_defaults();

if accept_invalid_certs {
let config = if accept_invalid_certs {
config
.dangerous()
.set_certificate_verifier(Arc::new(DummyTlsVerifier));
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
} else {
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));

if let Some(ca) = root_cert_path {
let data = ca.data().await?;
let mut cursor = Cursor::new(data);
config
.root_store
.add_pem_file(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?;

for cert in rustls_pemfile::certs(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
{
cert_store
.add(&rustls::Certificate(cert))
.map_err(|err| Error::Tls(err.into()))?;
}
}

if accept_invalid_hostnames {
let verifier = WebPkiVerifier::new(cert_store, None);

config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
} else {
config
.dangerous()
.set_certificate_verifier(Arc::new(NoHostnameTlsVerifier));
.with_root_certificates(cert_store)
.with_no_client_auth()
}
}
};

Ok(Arc::new(config).into())
}
Expand All @@ -49,28 +64,42 @@ struct DummyTlsVerifier;
impl ServerCertVerifier for DummyTlsVerifier {
fn verify_server_cert(
&self,
_roots: &RootCertStore,
_presented_certs: &[Certificate],
_dns_name: DNSNameRef<'_>,
_end_entity: &rustls::Certificate,
_intermediates: &[rustls::Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
_now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}
}

pub struct NoHostnameTlsVerifier;
pub struct NoHostnameTlsVerifier {
verifier: WebPkiVerifier,
}

impl ServerCertVerifier for NoHostnameTlsVerifier {
fn verify_server_cert(
&self,
roots: &RootCertStore,
presented_certs: &[Certificate],
dns_name: DNSNameRef<'_>,
end_entity: &rustls::Certificate,
intermediates: &[rustls::Certificate],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
let verifier = WebPKIVerifier::new();
match verifier.verify_server_cert(roots, presented_certs, dns_name, ocsp_response) {
Err(TLSError::WebPKIError(webpki::Error::CertNotValidForName)) => {
now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
match self.verifier.verify_server_cert(
end_entity,
intermediates,
server_name,
scts,
ocsp_response,
now,
) {
Err(TlsError::InvalidCertificateData(reason))
if reason.contains("CertNotValidForName") =>
{
Ok(ServerCertVerified::assertion())
}
res => res,
Expand Down
6 changes: 3 additions & 3 deletions sqlx-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ runtime-async-std-native-tls = [
runtime-tokio-native-tls = ["_rt-tokio", "_tls-native-tls", "tokio-native-tls"]

runtime-actix-rustls = ["_rt-actix", "_tls-rustls", "tokio-rustls"]
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "async-rustls"]
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "futures-rustls"]
runtime-tokio-rustls = ["_rt-tokio", "_tls-rustls", "tokio-rustls"]

# Not used directly and not re-exported from sqlx
Expand All @@ -32,11 +32,11 @@ _tls-rustls = []

[dependencies]
async-native-tls = { version = "0.3.3", optional = true }
async-rustls = { version = "0.2.0", optional = true }
futures-rustls = { version = "0.22.0", optional = true }
actix-rt = { version = "2.0.0", default-features = false, optional = true }
async-std = { version = "1.7.0", features = ["unstable"], optional = true }
tokio-native-tls = { version = "0.3.0", optional = true }
tokio-rustls = { version = "0.22.0", optional = true }
tokio-rustls = { version = "0.23.0", optional = true }
native-tls = { version = "0.2.4", optional = true }
once_cell = { version = "1.4", features = ["std"], optional = true }

Expand Down
2 changes: 1 addition & 1 deletion sqlx-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@ pub use async_native_tls::{TlsConnector, TlsStream};
feature = "_rt-actix"
)),
))]
pub use async_rustls::{client::TlsStream, TlsConnector};
pub use futures_rustls::{client::TlsStream, TlsConnector};