Skip to content

Commit

Permalink
Merge pull request blackbeam#245 from prisma/inline-reading-settings
Browse files Browse the repository at this point in the history
Inline ops that read settings
  • Loading branch information
blackbeam authored Apr 20, 2023
2 parents 73dbb96 + cd1ae04 commit 5843e91
Showing 1 changed file with 37 additions and 23 deletions.
60 changes: 37 additions & 23 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use mysql_common::{
OldEofPacket, ResultSetTerminator, SslRequest,
},
proto::MySerialize,
row::Row,
};

use std::{
Expand Down Expand Up @@ -918,10 +919,8 @@ impl Conn {
conn.do_handshake_response().await?;
conn.continue_auth().await?;
conn.switch_to_compression()?;
conn.read_socket().await?;
conn.read_settings().await?;
conn.reconnect_via_socket_if_needed().await?;
conn.read_max_allowed_packet().await?;
conn.read_wait_timeout().await?;
conn.run_init_commands().await?;
conn.run_setup_commands().await?;

Expand Down Expand Up @@ -953,38 +952,53 @@ impl Conn {
Ok(())
}

/// Reads and stores socket address inside the connection.
/// Configures the connection based on server settings. In particular:
///
/// Do nothing if socket address is already in [`Opts`] or if `prefer_socket` is `false`.
async fn read_socket(&mut self) -> Result<()> {
if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
let row_opt = self.query_internal("SELECT @@socket").await?;
self.inner.socket = row_opt.unwrap_or(None);
/// * It reads and stores socket address inside the connection unless if socket address is
/// already in [`Opts`] or if `prefer_socket` is `false`.
///
/// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`]
///
/// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
///
async fn read_settings(&mut self) -> Result<()> {
let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none();
let read_max_allowed_packet = self.opts().max_allowed_packet().is_none();
let read_wait_timeout = self.opts().wait_timeout().is_none();

let settings: Option<Row> = if read_socket || read_max_allowed_packet || read_wait_timeout {
self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout")
.await?
} else {
None
};

// set socket inside the connection
if read_socket {
self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None);
}
Ok(())
}

/// Reads and stores `max_allowed_packet` in the connection.
async fn read_max_allowed_packet(&mut self) -> Result<()> {
let max_allowed_packet = if let Some(value) = self.opts().max_allowed_packet() {
Some(value)
// set max_allowed_packet
let max_allowed_packet = if read_max_allowed_packet {
settings
.as_ref()
.map(|s| s.get("@@max_allowed_packet"))
.unwrap()
} else {
self.query_internal("SELECT @@max_allowed_packet").await?
self.opts().max_allowed_packet()
};
if let Some(stream) = self.inner.stream.as_mut() {
stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET));
}
Ok(())
}

/// Reads and stores `wait_timeout` in the connection.
async fn read_wait_timeout(&mut self) -> Result<()> {
let wait_timeout = if let Some(value) = self.opts().wait_timeout() {
Some(value)
// set read_wait_timeout
let wait_timeout = if read_wait_timeout {
settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap()
} else {
self.query_internal("SELECT @@wait_timeout").await?
self.opts().wait_timeout()
};
self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64);

Ok(())
}

Expand Down

0 comments on commit 5843e91

Please sign in to comment.