Skip to content

Commit

Permalink
Merge pull request blackbeam#254 from blackbeam/issue-253
Browse files Browse the repository at this point in the history
Do not read unnecessary settings in Conn::read_settings
  • Loading branch information
blackbeam authored Aug 7, 2023
2 parents ddee16e + 668a7e4 commit f3cdaa0
Showing 1 changed file with 111 additions and 28 deletions.
139 changes: 111 additions & 28 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub mod pool;
pub mod routines;
pub mod stmt_cache;

const DEFAULT_WAIT_TIMEOUT: usize = 28800;

/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
fn disconnect(mut conn: Conn) {
let disconnected = conn.inner.disconnected;
Expand Down Expand Up @@ -962,42 +964,123 @@ impl Conn {
/// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
///
async fn read_settings(&mut self) -> Result<()> {
let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none();
let read_max_allowed_packet = self.opts().max_allowed_packet().is_none();
let read_wait_timeout = self.opts().wait_timeout().is_none();
enum Action {
Load(Cfg),
Apply(CfgData),
}

let settings: Option<Row> = if read_socket || read_max_allowed_packet || read_wait_timeout {
self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout")
.await?
} else {
None
};
enum CfgData {
MaxAllowedPacket(usize),
WaitTimeout(usize),
}

// set socket inside the connection
if read_socket {
self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None);
impl CfgData {
fn apply(&self, conn: &mut Conn) {
match self {
Self::MaxAllowedPacket(value) => {
if let Some(stream) = conn.inner.stream.as_mut() {
stream.set_max_allowed_packet(*value);
}
}
Self::WaitTimeout(value) => {
conn.inner.wait_timeout = Duration::from_secs(*value as u64);
}
}
}
}

// set max_allowed_packet
let max_allowed_packet = if read_max_allowed_packet {
settings
.as_ref()
.map(|s| s.get("@@max_allowed_packet"))
.unwrap()
} else {
self.opts().max_allowed_packet()
};
if let Some(stream) = self.inner.stream.as_mut() {
stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET));
enum Cfg {
Socket,
MaxAllowedPacket,
WaitTimeout,
}

// set read_wait_timeout
let wait_timeout = if read_wait_timeout {
settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap()
impl Cfg {
const fn name(&self) -> &'static str {
match self {
Self::Socket => "@@socket",
Self::MaxAllowedPacket => "@@max_allowed_packet",
Self::WaitTimeout => "@@wait_timeout",
}
}

fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
match self {
Cfg::Socket => {
conn.inner.socket = value.map(crate::from_value).flatten();
}
Cfg::MaxAllowedPacket => {
if let Some(stream) = conn.inner.stream.as_mut() {
stream.set_max_allowed_packet(
value
.map(crate::from_value)
.flatten()
.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
);
}
}
Cfg::WaitTimeout => {
conn.inner.wait_timeout = Duration::from_secs(
value
.map(crate::from_value)
.flatten()
.unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
);
}
}
}
}

let mut actions = vec![
if let Some(x) = self.opts().max_allowed_packet() {
Action::Apply(CfgData::MaxAllowedPacket(x))
} else {
Action::Load(Cfg::MaxAllowedPacket)
},
if let Some(x) = self.opts().wait_timeout() {
Action::Apply(CfgData::WaitTimeout(x))
} else {
Action::Load(Cfg::WaitTimeout)
},
];

if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
actions.push(Action::Load(Cfg::Socket))
}

let loads = actions
.iter()
.filter_map(|x| match x {
Action::Load(x) => Some(x),
Action::Apply(_) => None,
})
.collect::<Vec<_>>();

let loaded = if !loads.is_empty() {
let query = loads
.iter()
.zip(std::iter::once(' ').chain(std::iter::repeat(',')))
.fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
acc.push(prefix);
acc.push_str(cfg.name());
acc
});

self.query_internal::<Row, String>(query)
.await?
.map(|row| row.unwrap())
.unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
} else {
self.opts().wait_timeout()
vec![]
};
self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64);
let mut loaded = loaded.into_iter();

for action in actions {
match action {
Action::Load(cfg) => cfg.apply(self, loaded.next()),
Action::Apply(cfg) => cfg.apply(self),
}
}

Ok(())
}
Expand Down

0 comments on commit f3cdaa0

Please sign in to comment.