diff --git a/Cargo.toml b/Cargo.toml index 15be228..3c7e87d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "tl-proto" description = "A collection of traits for working with TL serialization/deserialization" authors = ["Ivan Kalinin "] repository = "https://github.com/broxus/tl-proto" -version = "0.4.10" +version = "0.5.0" edition = "2021" include = ["src/**/*.rs", "README.md"] license = "MIT" @@ -18,7 +18,7 @@ sha2 = { version = "0.10", optional = true } smallvec = { version = "1.7", features = ["union", "const_generics"] } thiserror = "1.0.64" -tl-proto-proc = { version = "=0.4.7", path = "proc", optional = true } +tl-proto-proc = { version = "=0.5.0", path = "proc", optional = true } [features] default = ["derive", "bytes", "hash"] diff --git a/proc/Cargo.toml b/proc/Cargo.toml index 735c477..27df0ea 100644 --- a/proc/Cargo.toml +++ b/proc/Cargo.toml @@ -3,7 +3,7 @@ name = "tl-proto-proc" description = "A collection of traits for working with TL serialization/deserialization" authors = ["Ivan Kalinin "] repository = "https://github.com/broxus/tl-proto" -version = "0.4.7" +version = "0.5.0" edition = "2021" include = ["src/**/*.rs", "../README.md"] license = "MIT" diff --git a/proc/src/tl_read.rs b/proc/src/tl_read.rs index 2cc2d6a..7a265d5 100644 --- a/proc/src/tl_read.rs +++ b/proc/src/tl_read.rs @@ -119,8 +119,8 @@ fn build_enum(variants: &[ast::Variant]) -> TokenStream { }); quote! { - fn read_from(__packet: &'tl [u8], __offset: &mut usize) -> _tl_proto::TlResult { - match u32::read_from(__packet, __offset) { + fn read_from(__packet: &mut &'tl [u8]) -> _tl_proto::TlResult { + match u32::read_from(__packet) { Ok(constructor) => match constructor { #(#variants)* _ => Err(_tl_proto::TlError::UnknownConstructor) @@ -145,7 +145,7 @@ fn build_struct( .map(|id| { let id = id.unwrap_explicit(); quote! { - match u32::read_from(__packet, __offset) { + match u32::read_from(__packet) { Ok(constructor) => { if constructor != #id { return Err(_tl_proto::TlError::UnknownConstructor) @@ -160,7 +160,7 @@ fn build_struct( let read_from = build_read_from(quote! { Self }, style, fields); quote! { - fn read_from(__packet: &'tl [u8], __offset: &mut usize) -> _tl_proto::TlResult { + fn read_from(__packet: &mut &'tl [u8]) -> _tl_proto::TlResult { #(#prefix)* #read_from } @@ -194,7 +194,7 @@ fn build_read_from(ident: TokenStream, style: &ast::Style, fields: &[ast::Field] if field.attrs.flags { quote! { - let #ident = match >::read_from(__packet, __offset) { + let #ident = match >::read_from(__packet) { Ok(flags) => flags, Err(e) => return Err(e), }; @@ -210,14 +210,14 @@ fn build_read_from(ident: TokenStream, style: &ast::Style, fields: &[ast::Field] let read = if let Some(with) = &field.attrs.with { quote! { - match #with::read(__packet, __offset) { + match #with::read(__packet) { Ok(value) => value, Err(e) => return Err(e), } } } else if let Some(read_with) = &field.attrs.read_with { quote! { - match #read_with(__packet, __offset) { + match #read_with(__packet) { Ok(value) => value, Err(e) => return Err(e), } @@ -225,7 +225,7 @@ fn build_read_from(ident: TokenStream, style: &ast::Style, fields: &[ast::Field] } else { quote! { match <<#ty as IntoIterator>::Item as _tl_proto::TlRead<'tl>>::read_from( - __packet, __offset, + __packet, ) { Ok(value) => value, Err(e) => return Err(e), @@ -242,21 +242,21 @@ fn build_read_from(ident: TokenStream, style: &ast::Style, fields: &[ast::Field] } } else if let Some(with) = &field.attrs.with { quote! { - let #ident = match #with::read(__packet, __offset) { + let #ident = match #with::read(__packet) { Ok(value) => value, Err(e) => return Err(e), }; } } else if let Some(read_with) = &field.attrs.read_with { quote! { - let #ident = match #read_with(__packet, __offset) { + let #ident = match #read_with(__packet) { Ok(value) => value, Err(e) => return Err(e), }; } } else { quote! { - let #ident = match <#ty as _tl_proto::TlRead<'tl>>::read_from(__packet, __offset) { + let #ident = match <#ty as _tl_proto::TlRead<'tl>>::read_from(__packet) { Ok(value) => value, Err(e) => return Err(e), }; diff --git a/src/boxed.rs b/src/boxed.rs index ab63321..8c3a05d 100644 --- a/src/boxed.rs +++ b/src/boxed.rs @@ -56,9 +56,9 @@ where { type Repr = Boxed; - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match u32::read_from(packet, offset) { - Ok(id) if id == T::TL_ID => match T::read_from(packet, offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match u32::read_from(packet) { + Ok(id) if id == T::TL_ID => match T::read_from(packet) { Ok(data) => Ok(BoxedWrapper(data)), Err(e) => Err(e), }, diff --git a/src/lib.rs b/src/lib.rs index 0007329..a0b2697 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,11 +32,11 @@ mod tuple; mod util; /// Tries to deserialize `T` from the TL representation. -pub fn deserialize<'a, T>(packet: &'a [u8]) -> TlResult +pub fn deserialize<'a, T>(mut packet: &'a [u8]) -> TlResult where T: TlRead<'a>, { - T::read_from(packet, &mut 0) + T::read_from(&mut packet) } /// Tries to deserialize `T` as boxed from the TL representation. diff --git a/src/primitive.rs b/src/primitive.rs index 6edae02..1fcbd47 100644 --- a/src/primitive.rs +++ b/src/primitive.rs @@ -1,10 +1,9 @@ use crate::traits::*; -use crate::util::*; impl TlRead<'_> for () { type Repr = Bare; - fn read_from(_packet: &'_ [u8], _offset: &mut usize) -> TlResult { + fn read_from(_packet: &mut &'_ [u8]) -> TlResult { Ok(()) } } @@ -28,8 +27,8 @@ impl TlWrite for () { impl TlRead<'_> for bool { type Repr = Boxed; - fn read_from(packet: &[u8], offset: &mut usize) -> TlResult { - match u32::read_from(packet, offset) { + fn read_from(packet: &mut &'_ [u8]) -> TlResult { + match u32::read_from(packet) { Ok(BOOL_TRUE) => Ok(true), Ok(BOOL_FALSE) => Ok(false), Ok(_) => Err(TlError::UnknownConstructor), @@ -61,15 +60,14 @@ macro_rules! impl_read_from_packet( type Repr = Bare; #[inline(always)] - fn read_from(packet: &[u8], offset: &mut usize) -> TlResult { - if unlikely(packet.len() < *offset + std::mem::size_of::<$ty>()) { - Err(TlError::UnexpectedEof) - } else { - let value = <$ty>::from_le_bytes(unsafe { - *(packet.as_ptr().add(*offset) as *const [u8; std::mem::size_of::<$ty>()]) - }); - *offset += std::mem::size_of::<$ty>(); - Ok(value) + fn read_from(packet: &mut &'_ [u8]) -> TlResult { + match packet.split_first_chunk() { + Some((first, tail)) => { + let value = <$ty>::from_le_bytes(*first); + *packet = tail; + Ok(value) + } + None => Err(TlError::UnexpectedEof), } } } @@ -233,8 +231,8 @@ macro_rules! impl_non_zero { type Repr = Bare; #[inline(always)] - fn read_from(packet: &[u8], offset: &mut usize) -> TlResult { - match <$ty>::new(<$read_ty>::read_from(packet, offset)?) { + fn read_from(packet: &mut &'_ [u8]) -> TlResult { + match <$ty>::new(<$read_ty>::read_from(packet)?) { Some(value) => Ok(value), None => Err(TlError::InvalidData), } @@ -276,71 +274,67 @@ mod test { fn read_non_zero() { // u32 assert!(matches!( - std::num::NonZeroU32::read_from(&[0, 0], &mut 0).unwrap_err(), + std::num::NonZeroU32::read_from(&mut [0, 0].as_ref()).unwrap_err(), TlError::UnexpectedEof )); assert!(matches!( - std::num::NonZeroU32::read_from(&[0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroU32::read_from(&mut [0, 0, 0, 0].as_ref()).unwrap_err(), TlError::InvalidData )); - let mut offset = 0; + let mut packet: &[u8] = &[123, 0, 0, 0]; assert_eq!( - std::num::NonZeroU32::read_from(&[123, 0, 0, 0], &mut offset).unwrap(), + std::num::NonZeroU32::read_from(&mut packet).unwrap(), std::num::NonZeroU32::new(123).unwrap(), ); - assert_eq!(offset, 4); + assert!(packet.is_empty()); // i32 assert!(matches!( - std::num::NonZeroI32::read_from(&[0, 0], &mut 0).unwrap_err(), + std::num::NonZeroI32::read_from(&mut [0, 0].as_ref()).unwrap_err(), TlError::UnexpectedEof )); assert!(matches!( - std::num::NonZeroI32::read_from(&[0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroI32::read_from(&mut [0, 0, 0, 0].as_ref()).unwrap_err(), TlError::InvalidData )); - let mut offset = 0; + let mut packet: &[u8] = &[0xfe, 0xff, 0xff, 0xff]; assert_eq!( - std::num::NonZeroI32::read_from(&[0xfe, 0xff, 0xff, 0xff], &mut offset).unwrap(), + std::num::NonZeroI32::read_from(&mut packet).unwrap(), std::num::NonZeroI32::new(-2).unwrap(), ); - assert_eq!(offset, 4); + assert!(packet.is_empty()); // u64 assert!(matches!( - std::num::NonZeroU64::read_from(&[0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroU64::read_from(&mut [0, 0, 0, 0].as_ref()).unwrap_err(), TlError::UnexpectedEof )); assert!(matches!( - std::num::NonZeroU64::read_from(&[0, 0, 0, 0, 0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroU64::read_from(&mut [0, 0, 0, 0, 0, 0, 0, 0].as_ref()).unwrap_err(), TlError::InvalidData )); - let mut offset = 0; + let mut packet: &[u8] = &[123, 0, 0, 0, 0, 0, 0, 0]; assert_eq!( - std::num::NonZeroU64::read_from(&[123, 0, 0, 0, 0, 0, 0, 0], &mut offset).unwrap(), + std::num::NonZeroU64::read_from(&mut packet).unwrap(), std::num::NonZeroU64::new(123).unwrap(), ); - assert_eq!(offset, 8); + assert!(packet.is_empty()); // i64 assert!(matches!( - std::num::NonZeroI64::read_from(&[0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroI64::read_from(&mut [0, 0, 0, 0].as_ref()).unwrap_err(), TlError::UnexpectedEof )); assert!(matches!( - std::num::NonZeroI64::read_from(&[0, 0, 0, 0, 0, 0, 0, 0], &mut 0).unwrap_err(), + std::num::NonZeroI64::read_from(&mut [0, 0, 0, 0, 0, 0, 0, 0].as_ref()).unwrap_err(), TlError::InvalidData )); - let mut offset = 0; + let mut packet: &[u8] = &[0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]; assert_eq!( - std::num::NonZeroI64::read_from( - &[0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], - &mut offset - ) - .unwrap(), + std::num::NonZeroI64::read_from(&mut packet,).unwrap(), std::num::NonZeroI64::new(-2).unwrap(), ); - assert_eq!(offset, 8); + assert!(packet.is_empty()); } } diff --git a/src/seq.rs b/src/seq.rs index d185bed..7d72021 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -20,8 +20,8 @@ impl<'a> TlRead<'a> for BytesMeta { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match compute_bytes_meta(packet, *offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match compute_bytes_meta(packet) { Ok((prefix_len, len, padding)) => Ok(Self { prefix_len, len, @@ -37,8 +37,8 @@ impl<'a> TlRead<'a> for &'a [u8] { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - read_bytes(packet, offset) + fn read_from(packet: &mut &'a [u8]) -> TlResult { + read_bytes(packet) } } @@ -65,8 +65,8 @@ impl<'a> TlRead<'a> for Box<[u8]> { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - Ok(Box::from(ok!(read_bytes(packet, offset)))) + fn read_from(packet: &mut &'a [u8]) -> TlResult { + Ok(Box::from(ok!(read_bytes(packet)))) } } @@ -93,8 +93,8 @@ impl<'a> TlRead<'a> for Vec { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match read_bytes(packet, offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match read_bytes(packet) { Ok(bytes) => Ok(bytes.to_vec()), Err(e) => Err(e), } @@ -125,9 +125,9 @@ impl TlRead<'_> for bytes::Bytes { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'_ [u8], offset: &mut usize) -> TlResult { - match read_bytes(packet, offset) { - Ok(bytes) => Ok(bytes::Bytes::from(Box::from(bytes))), + fn read_from(packet: &mut &'_ [u8]) -> TlResult { + match read_bytes(packet) { + Ok(bytes) => Ok(bytes::Bytes::copy_from_slice(bytes)), Err(e) => Err(e), } } @@ -157,8 +157,8 @@ impl<'a, const N: usize> TlRead<'a> for &'a [u8; N] { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - read_fixed_bytes(packet, offset) + fn read_from(packet: &mut &'a [u8]) -> TlResult { + read_fixed_bytes(packet) } } @@ -167,8 +167,8 @@ impl<'a, const N: usize> TlRead<'a> for [u8; N] { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match read_fixed_bytes(packet, offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match read_fixed_bytes(packet) { Ok(data) => Ok(*data), Err(e) => Err(e), } @@ -201,12 +201,12 @@ where { type Repr = Bare; - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - let len = ok!(read_vector_len(packet, offset)); + fn read_from(packet: &mut &'a [u8]) -> TlResult { + let len = ok!(read_vector_len(packet)); let mut items = SmallVec::<[T; N]>::with_capacity(len); for _ in 0..len { - items.push(ok!(TlRead::read_from(packet, offset))); + items.push(ok!(TlRead::read_from(packet))); } Ok(items) } @@ -219,12 +219,12 @@ where { type Repr = Bare; - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - let len = ok!(read_vector_len(packet, offset)); + fn read_from(packet: &mut &'a [u8]) -> TlResult { + let len = ok!(read_vector_len(packet)); let mut items = Vec::with_capacity(len); for _ in 0..len { - items.push(ok!(TlRead::read_from(packet, offset))); + items.push(ok!(TlRead::read_from(packet))); } Ok(items) } @@ -393,27 +393,30 @@ impl<'a, const N: usize> TlRead<'a> for &'a BoundedBytes { type Repr = Bare; #[inline] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { + fn read_from(packet: &mut &'a [u8]) -> TlResult { fn read_bytes_with_max_len<'a>( - packet: &'a [u8], + packet: &mut &'a [u8], max_len: usize, - offset: &mut usize, ) -> TlResult<&'a [u8]> { - let current_offset = *offset; - let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset)); + let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet)); if len > max_len { return Err(TlError::InvalidData); } - let result = unsafe { - std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len) - }; + let packet_ptr = packet.as_ptr(); + let result = unsafe { std::slice::from_raw_parts(packet_ptr.add(prefix_len), len) }; - *offset += prefix_len + len + padding; + let skip_len = prefix_len + len + padding; + *packet = unsafe { + std::slice::from_raw_parts( + packet_ptr.add(skip_len), + packet.len().unchecked_sub(skip_len), + ) + }; Ok(result) } - let result = ok!(read_bytes_with_max_len(packet, N, offset)); + let result = ok!(read_bytes_with_max_len(packet, N)); // SAFETY: `len <= N` Ok(unsafe { BoundedBytes::wrap_unchecked(result) }) @@ -457,9 +460,9 @@ where { type Repr = Bare; - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match read_bytes(packet, offset) { - Ok(intermediate) => match T::read_from(intermediate, &mut 0) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match read_bytes(packet) { + Ok(mut intermediate) => match T::read_from(&mut intermediate) { Ok(data) => Ok(IntermediateBytes(data)), Err(e) => Err(e), }, @@ -542,10 +545,11 @@ impl Clone for RawBytes<'_, R> { impl<'a, R: Repr> TlRead<'a> for RawBytes<'a, R> { type Repr = R; - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - let len = packet.len() - std::cmp::min(*offset, packet.len()); - let result = unsafe { std::slice::from_raw_parts(packet.as_ptr().add(*offset), len) }; - *offset += len; + fn read_from(packet: &mut &'a [u8]) -> TlResult { + let result = *packet; + // NOTE: Assign the end of the packet instead of just empty slice to + // to leave the pointer at the same location. + *packet = &packet[packet.len()..]; Ok(Self::new(result)) } } @@ -614,8 +618,8 @@ impl AsRef<[u8]> for OwnedRawBytes { impl TlRead<'_> for OwnedRawBytes { type Repr = R; - fn read_from(packet: &'_ [u8], offset: &mut usize) -> TlResult { - match RawBytes::::read_from(packet, offset) { + fn read_from(packet: &mut &'_ [u8]) -> TlResult { + match RawBytes::::read_from(packet) { Ok(RawBytes(inner, ..)) => Ok(Self::new(inner.to_vec())), Err(e) => Err(e), } @@ -640,12 +644,12 @@ impl TlWrite for OwnedRawBytes { } #[inline(always)] -fn read_vector_len(packet: &[u8], offset: &mut usize) -> TlResult { - let len = ok!(u32::read_from(packet, offset)) as usize; +fn read_vector_len(packet: &mut &[u8]) -> TlResult { + let len = ok!(u32::read_from(packet)) as usize; // Length cannot be greater than the rest of the packet. // However min item size is 4 bytes so we could reduce it four times - if unlikely((len * 4 + *offset) > packet.len()) { + if unlikely((len * 4) > packet.len()) { Err(TlError::UnexpectedEof) } else { Ok(len) @@ -653,16 +657,13 @@ fn read_vector_len(packet: &[u8], offset: &mut usize) -> TlResult { } #[inline(always)] -fn read_fixed_bytes<'a, const N: usize>( - packet: &'a [u8], - offset: &mut usize, -) -> TlResult<&'a [u8; N]> { - if unlikely(packet.len() < *offset + N) { - Err(TlError::UnexpectedEof) - } else { - let ptr = unsafe { &*(packet.as_ptr().add(*offset) as *const [u8; N]) }; - *offset += N; - Ok(ptr) +fn read_fixed_bytes<'a, const N: usize>(packet: &mut &'a [u8]) -> TlResult<&'a [u8; N]> { + match packet.split_first_chunk() { + Some((chunk, tail)) => { + *packet = tail; + Ok(chunk) + } + None => Err(TlError::UnexpectedEof), } } @@ -706,21 +707,24 @@ where let remainder = have_written % 4; if remainder != 0 { - let buf = [0u8; 4]; - packet.write_raw_slice(&buf[remainder..]); + packet.write_raw_slice(&[0u8; 4][remainder..]); } } #[inline(always)] -fn read_bytes<'a>(packet: &'a [u8], offset: &mut usize) -> TlResult<&'a [u8]> { - let current_offset = *offset; - let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset)); - - let result = unsafe { - std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len) +fn read_bytes<'a>(packet: &mut &'a [u8]) -> TlResult<&'a [u8]> { + let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet)); + + let packet_ptr = packet.as_ptr(); + let result = unsafe { std::slice::from_raw_parts(packet_ptr.add(prefix_len), len) }; + + let skip_len = prefix_len + len + padding; + *packet = unsafe { + std::slice::from_raw_parts( + packet_ptr.add(skip_len), + packet.len().unchecked_sub(skip_len), + ) }; - - *offset += prefix_len + len + padding; Ok(result) } @@ -728,13 +732,13 @@ fn read_bytes<'a>(packet: &'a [u8], offset: &mut usize) -> TlResult<&'a [u8]> { /// /// Returns **prefix length**, **bytes length** and **padding length** #[inline(always)] -fn compute_bytes_meta(packet: &[u8], offset: usize) -> TlResult<(usize, usize, usize)> { +fn compute_bytes_meta(packet: &[u8]) -> TlResult<(usize, usize, usize)> { let packet_len = packet.len(); - if unlikely(packet_len <= offset + 4) { + if unlikely(packet_len < 4) { return Err(TlError::UnexpectedEof); } - let first_bytes = unsafe { packet.as_ptr().add(offset).cast::().read_unaligned() }; + let first_bytes = unsafe { packet.as_ptr().cast::().read_unaligned() }; let (len, have_read) = if first_bytes & 0xff != SIZE_MAGIC as u32 { ((first_bytes & 0xff) as usize, 1) } else { @@ -742,7 +746,7 @@ fn compute_bytes_meta(packet: &[u8], offset: usize) -> TlResult<(usize, usize, u }; let padding = (4 - (have_read + len) % 4) % 4; - if unlikely(packet_len < offset + have_read + len + padding) { + if unlikely(packet_len < have_read + len + padding) { return Err(TlError::UnexpectedEof); } @@ -750,3 +754,13 @@ fn compute_bytes_meta(packet: &[u8], offset: usize) -> TlResult<(usize, usize, u } const SIZE_MAGIC: u8 = 254; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn read_small_slice() { + assert_eq!(read_bytes(&mut [1, 123, 0, 0].as_ref()).unwrap(), &[123]); + } +} diff --git a/src/traits.rs b/src/traits.rs index 5e1d8e5..b5a6682 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -24,7 +24,7 @@ pub trait TlRead<'a>: Sized { type Repr: Repr; /// Tries to read itself from bytes at the specified offset, incrementing that offset. - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult; + fn read_from(packet: &mut &'a [u8]) -> TlResult; } impl<'a, T> TlRead<'a> for Arc @@ -34,8 +34,8 @@ where type Repr = T::Repr; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match T::read_from(packet, offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match T::read_from(packet) { Ok(data) => Ok(Arc::new(data)), Err(e) => Err(e), } @@ -49,8 +49,8 @@ where type Repr = T::Repr; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - match T::read_from(packet, offset) { + fn read_from(packet: &mut &'a [u8]) -> TlResult { + match T::read_from(packet) { Ok(data) => Ok(Box::new(data)), Err(e) => Err(e), } diff --git a/src/tuple.rs b/src/tuple.rs index a51e35f..344deda 100644 --- a/src/tuple.rs +++ b/src/tuple.rs @@ -10,8 +10,8 @@ macro_rules! impl_traits_for_tuple { type Repr = Bare; #[inline(always)] - fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { - Ok(($(ok!($ty::read_from(packet, offset))),*,)) + fn read_from(packet: &mut &'a [u8]) -> TlResult { + Ok(($(ok!($ty::read_from(packet))),*,)) } } diff --git a/test_suite/tests/tl_read.rs b/test_suite/tests/tl_read.rs index ce2baa3..87212ef 100644 --- a/test_suite/tests/tl_read.rs +++ b/test_suite/tests/tl_read.rs @@ -69,29 +69,23 @@ mod tests { signature: &'tl [u8], } - fn read_f32(mut packet: &[u8], offset: &mut usize) -> TlResult { - use std::io::Read; - - let mut bytes = [0; 4]; - packet - .read_exact(&mut bytes) - .map_err(|_| TlError::UnexpectedEof)?; - *offset += 4; - Ok(u32::from_le_bytes(bytes) as f32) + fn read_f32(packet: &mut &[u8]) -> TlResult { + let Some((bytes, tail)) = packet.split_first_chunk() else { + return Err(TlError::UnexpectedEof); + }; + *packet = tail; + Ok(u32::from_le_bytes(*bytes) as f32) } mod tl_u128 { use super::*; - pub fn read(packet: &[u8], offset: &mut usize) -> TlResult { - use std::io::Read; - - let mut bytes = [0; 16]; - (&packet[*offset..]) - .read_exact(&mut bytes) - .map_err(|_| TlError::UnexpectedEof)?; - *offset += 16; - Ok(u128::from_be_bytes(bytes)) + pub fn read(packet: &mut &[u8]) -> TlResult { + let Some((bytes, tail)) = packet.split_first_chunk() else { + return Err(TlError::UnexpectedEof); + }; + *packet = tail; + Ok(u128::from_be_bytes(*bytes)) } }