Skip to content

Commit

Permalink
first step to support custom serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
shenjackyuanjie committed Apr 17, 2024
1 parent fe3c3d8 commit 9529d64
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 47 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions engineio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ license = "MIT"
all-features = true

[dependencies]
base64 = "0.21.5"
base64 = "0.22.0"
bytes = "1"
reqwest = { version = "0.12.3", features = ["blocking", "native-tls", "stream"] }
adler32 = "1.2.0"
Expand All @@ -29,7 +29,7 @@ async-trait = "0.1.79"
async-stream = "0.3.5"
thiserror = "1.0"
native-tls = "0.2.11"
url = "2.4.1"
url = "2.5.0"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio"] }
Expand Down
17 changes: 2 additions & 15 deletions engineio/src/asynchronous/async_transports/polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,11 @@ impl Stream for PollingTransport {

#[async_trait]
impl AsyncTransport for PollingTransport {
async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
let data_to_send = if is_binary_att {
// the binary attachment gets `base64` encoded
let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
packet_bytes.put_u8(b'b');

let encoded_data = general_purpose::STANDARD.encode(data);
packet_bytes.put(encoded_data.as_bytes());

packet_bytes.freeze()
} else {
data
};

async fn emit(&self, data: Bytes) -> Result<()> {
let status = self
.client
.post(self.address().await?)
.body(data_to_send)
.body(data)
.send()
.await?
.status()
Expand Down
14 changes: 13 additions & 1 deletion engineio/src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::transport::Transport;

use crate::error::{Error, Result};
use crate::header::HeaderMap;
use crate::packet::{HandshakePacket, Packet, PacketId};
use crate::packet::{HandshakePacket, Packet, PacketId, PacketSerializer};
use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport};
use crate::ENGINE_IO_VERSION;
use bytes::Bytes;
Expand All @@ -32,6 +32,7 @@ pub struct ClientBuilder {
url: Url,
tls_config: Option<TlsConnector>,
headers: Option<HeaderMap>,
serializer: PacketSerializer,
handshake: Option<HandshakePacket>,
on_error: OptionalCallback<String>,
on_open: OptionalCallback<()>,
Expand All @@ -55,6 +56,7 @@ impl ClientBuilder {
headers: None,
tls_config: None,
handshake: None,
serializer: PacketSerializer::default(),
on_close: OptionalCallback::default(),
on_data: OptionalCallback::default(),
on_error: OptionalCallback::default(),
Expand All @@ -63,6 +65,13 @@ impl ClientBuilder {
}
}

/// Specify Packet Serializer
pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self {
self.serializer = packet_serializer;

self
}

/// Specify transport's tls config
pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
self.tls_config = Some(tls_config);
Expand Down Expand Up @@ -183,6 +192,7 @@ impl ClientBuilder {
Ok(Client {
socket: InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand Down Expand Up @@ -228,6 +238,7 @@ impl ClientBuilder {
Ok(Client {
socket: InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand All @@ -250,6 +261,7 @@ impl ClientBuilder {
Ok(Client {
socket: InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand Down
1 change: 0 additions & 1 deletion engineio/src/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use http::{
HeaderValue as HttpHeaderValue,
};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::str::FromStr;

Expand Down
86 changes: 84 additions & 2 deletions engineio/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,94 @@ use base64::{engine::general_purpose, Engine as _};
use bytes::{BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::char;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::fmt::{Display, Formatter, Result as FmtResult, Write};
use std::ops::Index;

use crate::error::{Error, Result};

pub struct PacketSerializer {
decode: Box<dyn Fn(Bytes) -> Result<Packet> + Send + Sync>,
encode: Box<dyn Fn(Packet) -> Bytes + Send + Sync>,
}

fn default_decode(bytes: Bytes) -> Result<Packet> {
if bytes.is_empty() {
return Err(Error::IncompletePacket());
}

let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b';

// only 'messages' packets could be encoded
let packet_id = if is_base64 {
PacketId::MessageBinary
} else {
(*bytes.first().ok_or(Error::IncompletePacket())?).try_into()?
};

if bytes.len() == 1 && packet_id == PacketId::Message {
return Err(Error::IncompletePacket());
}

let data: Bytes = bytes.slice(1..);

Ok(Packet {
packet_id,
data: if is_base64 {
Bytes::from(general_purpose::STANDARD.decode(data.as_ref())?)
} else {
data
},
})
}

fn default_encode(packet: Packet) -> Bytes {
let mut result = BytesMut::with_capacity(packet.data.len() + 1);
result.put_u8(packet.packet_id.to_string_byte());
if packet.packet_id == PacketId::MessageBinary {
result.extend(general_purpose::STANDARD.encode(packet.data).into_bytes());
} else {
result.put(packet.data);
}
result.freeze()
}


impl PacketSerializer {
const SEPARATOR: char = '\x1e';

pub fn new(
decode: Box<dyn Fn(Bytes) -> Result<Packet> + Send + Sync>,
encode: Box<dyn Fn(Packet) -> Bytes + Send + Sync>,
) -> Self {
Self {
decode,
encode,
}
}

pub fn default() -> Self {
let decode = Box::new(default_decode);
let encode = Box::new(default_encode);
Self::new(decode, encode)
}

pub fn decode(&self, datas: Bytes) -> Result<Packet> {
(self.decode)(datas)
}

pub fn decode_payload(&self, datas: Bytes) -> Result<Payload> {
datas
.split(|&c| c as char == PacketSerializer::SEPARATOR)
.map(|slice| self.decode(datas.slice_ref(slice)))
.collect::<Result<Vec<Packet>>>()
.map(Payload)
}

pub fn encode(&self, packet: Packet) -> Bytes {
(self.encode)(packet)
}
}

/// Enumeration of the `engine.io` `Packet` types.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum PacketId {
Expand Down
Empty file.
Empty file added engineio/src/packet/normal.rs
Empty file.
9 changes: 6 additions & 3 deletions engineio/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use crate::callback::OptionalCallback;
use crate::transport::TransportType;

use crate::error::{Error, Result};
use crate::packet::{HandshakePacket, Packet, PacketId, Payload};
use crate::packet::{HandshakePacket, Packet, PacketId, PacketSerializer, Payload};
use bytes::Bytes;
use std::convert::TryFrom;
use std::sync::RwLock;
use std::time::Duration;
use std::{fmt::Debug, sync::atomic::Ordering};
Expand All @@ -23,6 +22,7 @@ pub const DEFAULT_MAX_POLL_TIMEOUT: Duration = Duration::from_secs(45);
#[derive(Clone)]
pub struct Socket {
transport: Arc<TransportType>,
serializer: PacketSerializer,
on_close: OptionalCallback<()>,
on_data: OptionalCallback<Bytes>,
on_error: OptionalCallback<String>,
Expand All @@ -40,6 +40,7 @@ pub struct Socket {
impl Socket {
pub(crate) fn new(
transport: TransportType,
serializer: PacketSerializer,
handshake: HandshakePacket,
on_close: OptionalCallback<()>,
on_data: OptionalCallback<Bytes>,
Expand All @@ -56,6 +57,7 @@ impl Socket {
on_open,
on_packet,
transport: Arc::new(transport),
serializer,
connected: Arc::new(AtomicBool::default()),
last_ping: Arc::new(Mutex::new(Instant::now())),
last_pong: Arc::new(Mutex::new(Instant::now())),
Expand Down Expand Up @@ -148,7 +150,8 @@ impl Socket {
continue;
}

let payload = Payload::try_from(data)?;
// let payload = Payload::try_from(data)?;
let payload = self.serializer.decode_payload(data)?;
let mut iter = payload.into_iter();

if let Some(packet) = iter.next() {
Expand Down
16 changes: 2 additions & 14 deletions engineio/src/transports/polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,11 @@ impl PollingTransport {
}

impl Transport for PollingTransport {
fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
let data_to_send = if is_binary_att {
// the binary attachment gets `base64` encoded
let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
packet_bytes.put_u8(b'b');

let encoded_data = general_purpose::STANDARD.encode(data);
packet_bytes.put(encoded_data.as_bytes());

packet_bytes.freeze()
} else {
data
};
fn emit(&self, data: Bytes, _is_binary_att: bool) -> Result<()> {
let status = self
.client
.post(self.address()?)
.body(data_to_send)
.body(data)
.send()?
.status()
.as_u16();
Expand Down
29 changes: 27 additions & 2 deletions socketio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rust_engineio::{
use std::collections::HashMap;
use url::Url;

use crate::{error::Result, Event, Payload, TransportType};
use crate::{error::Result, Event, PacketSerializer, Payload, TransportType};

use super::{
callback::{
Expand All @@ -31,6 +31,7 @@ pub struct ClientBuilder {
tls_config: Option<TlsConnector>,
opening_headers: Option<HeaderMap>,
transport_type: TransportType,
packet_serializer: PacketSerializer,
pub(crate) auth: Option<serde_json::Value>,
pub(crate) reconnect: bool,
pub(crate) reconnect_on_disconnect: bool,
Expand Down Expand Up @@ -89,7 +90,8 @@ impl ClientBuilder {
namespace: "/".to_owned(),
tls_config: None,
opening_headers: None,
transport_type: TransportType::Any,
transport_type: TransportType::default(),
packet_serializer: PacketSerializer::default(),
auth: None,
reconnect: true,
reconnect_on_disconnect: false,
Expand Down Expand Up @@ -395,6 +397,29 @@ impl ClientBuilder {
self
}

/// Specifies the [`PacketSerializer`] to use for encoding and decoding packets.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::ClientBuilder, PacketSerializer};
///
/// #[tokio::main]
/// async fn main() {
/// let socket = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
/// .packet_serializer(PacketSerializer::Normal)
/// .connect()
/// .await
/// .expect("connection failed");
/// }
/// ```
pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self {
self.packet_serializer = packet_serializer;

self
}

/// Connects the socket to a certain endpoint. This returns a connected
/// [`Client`] instance. This method returns an [`std::result::Result::Err`]
/// value if something goes wrong during connection. Also starts a separate
Expand Down
Loading

0 comments on commit 9529d64

Please sign in to comment.