Skip to content

Commit

Permalink
Implement Conn::change_user
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Apr 12, 2023
1 parent ad90c52 commit 134cbf8
Show file tree
Hide file tree
Showing 9 changed files with 451 additions and 54 deletions.
233 changes: 206 additions & 27 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use crate::{
transaction::TxStatus,
BinaryProtocol, Queryable, TextProtocol,
},
BinlogStream, InfileData, OptsBuilder,
BinlogStream, ChangeUserOpts, InfileData, OptsBuilder,
};

use self::routines::Routine;
Expand Down Expand Up @@ -102,13 +102,15 @@ struct ConnInner {
pool: Option<Pool>,
pending_result: std::result::Result<Option<PendingResult>, ServerError>,
tx_status: TxStatus,
reset_upon_returning_to_a_pool: bool,
opts: Opts,
last_io: Instant,
wait_timeout: Duration,
stmt_cache: StmtCache,
nonce: Vec<u8>,
auth_plugin: AuthPlugin<'static>,
auth_switched: bool,
server_key: Option<Vec<u8>>,
/// Connection is already disconnected.
pub(crate) disconnected: bool,
/// One-time connection-level infile handler.
Expand All @@ -126,6 +128,8 @@ impl fmt::Debug for ConnInner {
.field("tx_status", &self.tx_status)
.field("stream", &self.stream)
.field("options", &self.opts)
.field("server_key", &self.server_key)
.field("auth_plugin", &self.auth_plugin)
.finish()
}
}
Expand Down Expand Up @@ -154,7 +158,9 @@ impl ConnInner {
auth_plugin: AuthPlugin::MysqlNativePassword,
auth_switched: false,
disconnected: false,
server_key: None,
infile_handler: None,
reset_upon_returning_to_a_pool: false,
}
}

Expand Down Expand Up @@ -416,16 +422,33 @@ impl Conn {
/// Returns true if io stream is encrypted.
fn is_secure(&self) -> bool {
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
if let Some(ref stream) = self.inner.stream {
stream.is_secure()
} else {
false
{
self.inner
.stream
.as_ref()
.map(|x| x.is_secure())
.unwrap_or_default()
}

#[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
false
}

/// Returns true if io stream is socket.
fn is_socket(&self) -> bool {
#[cfg(unix)]
{
self.inner
.stream
.as_ref()
.map(|x| x.is_socket())
.unwrap_or_default()
}

#[cfg(not(unix))]
false
}

/// Hacky way to move connection through &mut. `self` becomes unusable.
fn take(&mut self) -> Conn {
mem::replace(self, Conn::empty(Default::default()))
Expand Down Expand Up @@ -663,16 +686,21 @@ impl Conn {
let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
pass.as_mut().push(0);

if self.is_secure() {
if self.is_secure() || self.is_socket() {
self.write_packet(pass).await?;
} else {
self.write_bytes(&[0x02][..]).await?;
let packet = self.read_packet().await?;
let key = &packet[1..];
if self.inner.server_key.is_none() {
self.write_bytes(&[0x02][..]).await?;
let packet = self.read_packet().await?;
self.inner.server_key = Some(packet[1..].to_vec());
}
for (i, byte) in pass.as_mut().iter_mut().enumerate() {
*byte ^= self.inner.nonce[i % self.inner.nonce.len()];
}
let encrypted_pass = crypto::encrypt(&*pass, key);
let encrypted_pass = crypto::encrypt(
&*pass,
self.inner.server_key.as_deref().expect("unreachable"),
);
self.write_bytes(&*encrypted_pass).await?;
};
self.drop_packet().await?;
Expand Down Expand Up @@ -958,12 +986,13 @@ impl Conn {
self.inner.last_io.elapsed()
}

/// Executes `COM_RESET_CONNECTION` on `self`.
/// Executes [`COM_RESET_CONNECTION`][1].
///
/// If server version is older than 5.7.2, then it'll reconnect.
pub async fn reset(&mut self) -> Result<()> {
let pool = self.inner.pool.clone();

/// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
/// For older versions consider using [`Conn::change_user`].
///
/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
pub async fn reset(&mut self) -> Result<bool> {
let supports_com_reset_connection = if self.inner.is_mariadb {
self.inner.version >= (10, 2, 4)
} else {
Expand All @@ -973,19 +1002,62 @@ impl Conn {

if supports_com_reset_connection {
self.routine(routines::ResetRoutine).await?;
} else {
let opts = self.inner.opts.clone();
let old_conn = std::mem::replace(self, Conn::new(opts).await?);
// tidy up the old connection
old_conn.close_conn().await?;
};
self.inner.stmt_cache.clear();
self.inner.infile_handler = None;
}

Ok(supports_com_reset_connection)
}

/// Executes [`COM_CHANGE_USER`][1].
///
/// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
/// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
///
/// ## Note
///
/// * Using non-default `opts` for a pooled connection is discouraging.
/// * Connection options will be permanently updated.
///
/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
// We'll kick this connection from a pool if opts are changed.
if opts != ChangeUserOpts::default() {
let mut opts_changed = false;
if let Some(user) = opts.user() {
opts_changed |= user != self.opts().user()
};
if let Some(pass) = opts.pass() {
opts_changed |= pass != self.opts().pass()
};
if let Some(db_name) = opts.db_name() {
opts_changed |= db_name != self.opts().db_name()
};
if opts_changed {
if let Some(pool) = self.inner.pool.take() {
pool.cancel_connection();
}
}
}

let conn_opts = &mut self.inner.opts;
opts.update_opts(conn_opts);
self.routine(routines::ChangeUser).await?;
self.inner.stmt_cache.clear();
self.inner.infile_handler = None;
self.inner.pool = pool;
Ok(())
}

/// Resets the connection upon returning it to a pool.
///
/// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
async fn reset_for_pool(mut self) -> Result<Self> {
if !self.reset().await? {
self.change_user(Default::default()).await?;
}
Ok(self)
}

/// Requires that `self.inner.tx_status != TxStatus::None`
async fn rollback_transaction(&mut self) -> Result<()> {
debug_assert_ne!(self.inner.tx_status, TxStatus::None);
Expand Down Expand Up @@ -1094,13 +1166,14 @@ mod test {
use bytes::Bytes;
use futures_util::stream::{self, StreamExt};
use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN};
use rand::Fill;
use tokio::time::timeout;

use std::time::Duration;

use crate::{
from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn,
Error, OptsBuilder, Pool, WhiteListFsHandler,
from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest,
ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler,
};

async fn gen_dummy_data() -> super::Result<()> {
Expand Down Expand Up @@ -1471,9 +1544,115 @@ mod test {
#[tokio::test]
async fn should_reset_the_connection() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
conn.exec_drop("SELECT ?", (1_u8,)).await?;
conn.reset().await?;
conn.exec_drop("SELECT ?", (1_u8,)).await?;
let max_execution_time = conn
.query_first::<u64, _>("SELECT @@max_execution_time")
.await?
.unwrap();

conn.exec_drop(
"SET SESSION max_execution_time = ?",
(max_execution_time + 1,),
)
.await?;

assert_eq!(
conn.query_first::<u64, _>("SELECT @@max_execution_time")
.await?,
Some(max_execution_time + 1)
);

if conn.reset().await? {
assert_eq!(
conn.query_first::<u64, _>("SELECT @@max_execution_time")
.await?,
Some(max_execution_time)
);
} else {
assert_eq!(
conn.query_first::<u64, _>("SELECT @@max_execution_time")
.await?,
Some(max_execution_time + 1)
);
}

conn.disconnect().await?;
Ok(())
}

#[tokio::test]
async fn should_change_user() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
let max_execution_time = conn
.query_first::<u64, _>("SELECT @@max_execution_time")
.await?
.unwrap();

conn.exec_drop(
"SET SESSION max_execution_time = ?",
(max_execution_time + 1,),
)
.await?;

assert_eq!(
conn.query_first::<u64, _>("SELECT @@max_execution_time")
.await?,
Some(max_execution_time + 1)
);

conn.change_user(Default::default()).await?;
assert_eq!(
conn.query_first::<u64, _>("SELECT @@max_execution_time")
.await?,
Some(max_execution_time)
);

let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
&["mysql_native_password", "caching_sha2_password"]
} else {
&["mysql_native_password"]
};

for plugin in plugins {
let mut conn2 = Conn::new(get_opts()).await.unwrap();

let mut rng = rand::thread_rng();
let mut pass = [0u8; 10];
pass.try_fill(&mut rng).unwrap();
let pass: String = IntoIterator::into_iter(pass)
.map(|x| ((x % (123 - 97)) + 97) as char)
.collect();
conn.query_drop("DROP USER IF EXISTS __mysql_async_test_user")
.await
.unwrap();
conn.query_drop(format!(
"CREATE USER '__mysql_async_test_user'@'%' IDENTIFIED WITH {} BY {}",
plugin,
Value::from(pass.clone()).as_sql(false)
))
.await
.unwrap();
conn.query_drop("FLUSH PRIVILEGES").await.unwrap();

conn2
.change_user(
ChangeUserOpts::default()
.with_db_name(None)
.with_user(Some("__mysql_async_test_user".into()))
.with_pass(Some(pass)),
)
.await
.unwrap();
assert_eq!(
conn2
.query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
.await
.unwrap(),
Some((None, String::from("__mysql_async_test_user@localhost"))),
);

conn2.disconnect().await.unwrap();
}

conn.disconnect().await?;
Ok(())
}
Expand Down
14 changes: 10 additions & 4 deletions src/conn/pool/futures/get_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,18 @@ pub struct GetConn {
pub(crate) queue_id: Option<QueueId>,
pub(crate) pool: Option<Pool>,
pub(crate) inner: GetConnInner,
reset_upon_returning_to_a_pool: bool,
#[cfg(feature = "tracing")]
span: Arc<Span>,
}

impl GetConn {
pub(crate) fn new(pool: &Pool) -> GetConn {
pub(crate) fn new(pool: &Pool, reset_upon_returning_to_a_pool: bool) -> GetConn {
GetConn {
queue_id: None,
pool: Some(pool.clone()),
inner: GetConnInner::New,
reset_upon_returning_to_a_pool,
#[cfg(feature = "tracing")]
span: Arc::new(debug_span!("mysql_async::get_conn")),
}
Expand Down Expand Up @@ -141,6 +143,8 @@ impl Future for GetConn {
return match result {
Ok(mut c) => {
c.inner.pool = Some(pool);
c.inner.reset_upon_returning_to_a_pool =
self.reset_upon_returning_to_a_pool;
Poll::Ready(Ok(c))
}
Err(e) => {
Expand All @@ -152,12 +156,14 @@ impl Future for GetConn {
GetConnInner::Checking(ref mut f) => {
let result = ready!(Pin::new(f).poll(cx));
match result {
Ok(mut checked_conn) => {
Ok(mut c) => {
self.inner = GetConnInner::Done;

let pool = self.pool_take();
checked_conn.inner.pool = Some(pool);
return Poll::Ready(Ok(checked_conn));
c.inner.pool = Some(pool);
c.inner.reset_upon_returning_to_a_pool =
self.reset_upon_returning_to_a_pool;
return Poll::Ready(Ok(c));
}
Err(_) => {
// Idling connection is broken. We'll drop it and try again.
Expand Down
Loading

0 comments on commit 134cbf8

Please sign in to comment.