Skip to content

Commit

Permalink
feat:add tls support fot memcached
Browse files Browse the repository at this point in the history
  • Loading branch information
wangrui committed Dec 27, 2024
1 parent 86d6c42 commit df9e099
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 37 deletions.
32 changes: 17 additions & 15 deletions core/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,12 @@ services-ipmfs = []
services-koofr = []
services-lakefs = []
services-libsql = ["dep:hrana-client-proto"]
services-memcached = ["dep:bb8"]
services-memcached = [
"dep:bb8",
"dep:tokio-rustls",
"dep:webpki-roots",
"dep:rustls",
]
services-memory = []
services-mini-moka = ["dep:mini-moka"]
services-moka = ["dep:moka"]
Expand Down Expand Up @@ -359,6 +364,12 @@ monoio = { version = "0.2.4", optional = true, features = [
"unlinkat",
"renameat",
] }
# for services-memcached
rustls = { version = "0.23.15", default-features = false, features = [
"std",
], optional = true }
tokio-rustls = { version = "0.26.1", optional = true }
webpki-roots = { version = "0.26.7", optional = true }

# Layers
# for layers-async-backtrace
Expand Down
128 changes: 116 additions & 12 deletions core/src/services/memcached/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use bb8::RunError;
use tokio::net::TcpStream;
use tokio::sync::OnceCell;

use super::binary;
use crate::raw::adapters::kv;
use crate::raw::*;
use crate::services::MemcachedConfig;
use crate::*;

use bb8::RunError;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, ServerName};
use tokio::net::TcpStream;
use tokio::sync::OnceCell;
use tokio_rustls::TlsConnector;

impl Configurator for MemcachedConfig {
type Builder = MemcachedBuilder;
fn into_builder(self) -> Self::Builder {
Expand Down Expand Up @@ -82,6 +87,18 @@ impl MemcachedBuilder {
self.config.default_ttl = Some(ttl);
self
}

/// Set the tls connect on.
pub fn tls(mut self, tls: bool) -> Self {
self.config.tls = Some(tls);
self
}

/// Set the tls connect on.
pub fn cafile(mut self, cafile: PathBuf) -> Self {
self.config.cafile = Some(cafile);
self
}
}

impl Builder for MemcachedBuilder {
Expand Down Expand Up @@ -126,6 +143,14 @@ impl Builder for MemcachedBuilder {
.with_context("endpoint", &endpoint),
);
};
if self.config.tls.unwrap_or(false) {
ServerName::try_from(host.clone()).map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "Invalid dns name error")
.with_context("service", Scheme::Memcached)
.with_context("host", &host)
.set_source(err)
})?;
}
let port = if let Some(port) = uri.port_u16() {
port
} else {
Expand All @@ -150,6 +175,9 @@ impl Builder for MemcachedBuilder {
endpoint,
username: self.config.username.clone(),
password: self.config.password.clone(),
tls: self.config.tls.clone(),
cafile: self.config.cafile.clone(),
host,
conn,
default_ttl: self.config.default_ttl,
})
Expand All @@ -166,6 +194,9 @@ pub struct Adapter {
username: Option<String>,
password: Option<String>,
default_ttl: Option<Duration>,
tls: Option<bool>,
cafile: Option<PathBuf>,
host: String,
conn: OnceCell<bb8::Pool<MemcacheConnectionManager>>,
}

Expand All @@ -178,6 +209,9 @@ impl Adapter {
&self.endpoint,
self.username.clone(),
self.password.clone(),
self.tls.clone(),
self.cafile.clone(),
&self.host,
);

bb8::Pool::builder().build(mgr).await.map_err(|err| {
Expand Down Expand Up @@ -246,14 +280,27 @@ struct MemcacheConnectionManager {
address: String,
username: Option<String>,
password: Option<String>,
tls: Option<bool>,
cafile: Option<PathBuf>,
host: String,
}

impl MemcacheConnectionManager {
fn new(address: &str, username: Option<String>, password: Option<String>) -> Self {
fn new(
address: &str,
username: Option<String>,
password: Option<String>,
tls: Option<bool>,
cafile: Option<PathBuf>,
host: &str,
) -> Self {
Self {
address: address.to_string(),
username,
password,
tls,
cafile,
host: host.to_string(),
}
}
}
Expand All @@ -265,14 +312,71 @@ impl bb8::ManageConnection for MemcacheConnectionManager {

/// TODO: Implement unix stream support.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;
let mut conn = binary::Connection::new(conn);
let conn = if self.tls.unwrap_or(false) {
let mut root_cert_store = rustls::RootCertStore::empty();

if let Some(cafile) = &self.cafile {
for cert in CertificateDer::pem_file_iter(cafile).map_err(|err| match err {
rustls::pki_types::pem::Error::Io(err) => new_std_io_error(err),
_ => unreachable!(),
})? {
root_cert_store
.add(cert.map_err(|err| match err {
rustls::pki_types::pem::Error::Io(err) => new_std_io_error(err),
_ => unreachable!(),
})?)
.map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
})?;
}
} else {
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}

if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) {
conn.auth(username, password).await?;
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();

let connector = TlsConnector::from(Arc::new(config));
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;
let domain = ServerName::try_from(self.host.as_str())
.map_err(|err| {
Error::new(ErrorKind::ConfigInvalid, "Invalid dns name error")
.with_context("service", Scheme::Memcached)
.with_context("address", &self.address)
.set_source(err)
})?
.to_owned();

let conn = connector.connect(domain, conn).await.map_err(|err| {
Error::new(ErrorKind::Unexpected, "tls connect failed").set_source(err)
})?;

let mut conn = binary::TlsConnection::new(conn);

if let (Some(username), Some(password)) =
(self.username.as_ref(), self.password.as_ref())
{
conn.auth(username, password).await?;
}
binary::Connection::Tls(conn)
} else {
let conn = TcpStream::connect(&self.address)
.await
.map_err(new_std_io_error)?;

let mut conn = binary::TcpConnection::new(conn);

if let (Some(username), Some(password)) =
(self.username.as_ref(), self.password.as_ref())
{
conn.auth(username, password).await?;
}

binary::Connection::Tcp(conn)
};
Ok(conn)
}

Expand Down
Loading

0 comments on commit df9e099

Please sign in to comment.