Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mysql): support packet splitting #2665

Merged
merged 4 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions sqlx-mysql/src/connection/stream.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::VecDeque;
use std::ops::{Deref, DerefMut};

use bytes::{Buf, Bytes};
use bytes::{Buf, Bytes, BytesMut};

use crate::collation::{CharSet, Collation};
use crate::error::Error;
Expand Down Expand Up @@ -126,9 +126,7 @@ impl<S: Socket> MySqlStream<S> {
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
}

// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet

Expand All @@ -142,10 +140,33 @@ impl<S: Socket> MySqlStream<S> {
let payload: Bytes = self.socket.read(packet_size).await?;

// TODO: packet compression
// TODO: packet joining

Ok(payload)
}

// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
let payload = self.recv_packet_part().await?;
let payload = if payload.len() < 0xFF_FF_FF {
payload
} else {
let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
final_payload.extend_from_slice(&payload);

drop(payload); // we don't need the allocation anymore

let mut last_read = 0xFF_FF_FF;
while last_read == 0xFF_FF_FF {
let part = self.recv_packet_part().await?;
last_read = part.len();
final_payload.extend_from_slice(&part);
}
final_payload.into()
};

if payload
.get(0)
.first()
.ok_or(err_protocol!("Packet empty"))?
.eq(&0xff)
{
Expand Down
31 changes: 26 additions & 5 deletions sqlx-mysql/src/protocol/packet.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::ops::{Deref, DerefMut};

use bytes::Bytes;
Expand All @@ -19,6 +20,14 @@ where
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
let mut next_header = |len: u32| {
let mut buf = len.to_le_bytes();
buf[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);

buf
};

// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
Expand All @@ -31,13 +40,25 @@ where
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];

// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));

header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
// add more packets if we need to split the data
if len >= 0xFF_FF_FF {
let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
let mut chunks = rest.chunks_exact(0xFF_FF_FF);

*sequence_id = sequence_id.wrapping_add(1);
for chunk in chunks.by_ref() {
buf.reserve(chunk.len() + 4);
buf.extend(&next_header(chunk.len() as u32));
buf.extend(chunk);
}

// this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
let remainder = chunks.remainder();
buf.reserve(remainder.len() + 4);
buf.extend(&next_header(remainder.len() as u32));
buf.extend(remainder);
}
}
}

Expand Down
33 changes: 33 additions & 0 deletions tests/mysql/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,39 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_can_handle_split_packets() -> anyhow::Result<()> {
// This will only take effect on new connections
new::<MySql>()
.await?
.execute("SET GLOBAL max_allowed_packet = 4294967297")
.await?;

let mut conn = new::<MySql>().await?;

conn.execute(
r#"
CREATE TEMPORARY TABLE large_table (data LONGBLOB);
"#,
)
.await?;

let data = vec![0x41; 0xFF_FF_FF * 2];

sqlx::query("INSERT INTO large_table (data) VALUES (?)")
.bind(&data)
.execute(&mut conn)
.await?;

let ret: Vec<u8> = sqlx::query_scalar("SELECT * FROM large_table")
.fetch_one(&mut conn)
.await?;

assert_eq!(ret, data);

Ok(())
}

#[sqlx_macros::test]
async fn test_shrink_buffers() -> anyhow::Result<()> {
// We don't really have a good way to test that `.shrink_buffers()` functions as expected
Expand Down
Loading