From 1903c79d170991f40492e7e9f89f1606f8a9e197 Mon Sep 17 00:00:00 2001 From: Ivan Kalinin Date: Tue, 28 Nov 2023 19:22:23 +0100 Subject: [PATCH] Add `BoundedBytes` wrapper --- Cargo.toml | 2 +- src/seq.rs | 111 +++++++++++++++++++++++++++++++++++- test_suite/tests/tl_read.rs | 20 ++++++- 3 files changed, 130 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9cf9c36..8d7a12f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "tl-proto" description = "A collection of traits for working with TL serialization/deserialization" authors = ["Ivan Kalinin "] repository = "https://github.com/broxus/tl-proto" -version = "0.4.2" +version = "0.4.3" edition = "2021" include = ["src/**/*.rs", "README.md"] license = "MIT" diff --git a/src/seq.rs b/src/seq.rs index 24f78c2..2e452ac 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -329,6 +329,114 @@ where } } +/// Bytes slice with a max length bound. +#[derive(Debug)] +#[repr(transparent)] +pub struct BoundedBytes([u8]); + +impl BoundedBytes { + /// Wraps a byte slice into a new type with length check. + #[inline] + pub const fn try_wrap(bytes: &[u8]) -> Option<&Self> { + if bytes.len() <= N { + // SAFETY: `BoundedBytes` has the same repr as `[u8]` + Some(unsafe { &*(bytes as *const [u8] as *const BoundedBytes) }) + } else { + None + } + } + + /// Wraps a byte slice into a new type without any checks. + /// + /// # Safety + /// + /// The following must be true: + /// - `bytes` must have length not greater than `N` + #[inline] + pub unsafe fn wrap_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: `BoundedBytes` has the same repr as `[u8]` + unsafe { &*(bytes as *const [u8] as *const BoundedBytes) } + } +} + +impl AsRef<[u8]> for BoundedBytes { + #[inline] + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsMut<[u8]> for BoundedBytes { + #[inline] + fn as_mut(&mut self) -> &mut [u8] { + &mut self.0 + } +} + +impl std::ops::Deref for BoundedBytes { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for BoundedBytes { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, const N: usize> TlRead<'a> for &'a BoundedBytes { + type Repr = Bare; + + #[inline] + fn read_from(packet: &'a [u8], offset: &mut usize) -> TlResult { + fn read_bytes_with_max_len<'a>( + packet: &'a [u8], + max_len: usize, + offset: &mut usize, + ) -> TlResult<&'a [u8]> { + let current_offset = *offset; + let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset)); + if len > max_len { + return Err(TlError::InvalidData); + } + + let result = unsafe { + std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len) + }; + + *offset += prefix_len + len + padding; + Ok(result) + } + + let result = ok!(read_bytes_with_max_len(packet, N, offset)); + + // SAFETY: `len <= N` + Ok(unsafe { BoundedBytes::wrap_unchecked(result) }) + } +} + +impl TlWrite for &BoundedBytes { + type Repr = Bare; + + #[inline(always)] + fn max_size_hint(&self) -> usize { + bytes_max_size_hint(self.len()) + } + + #[inline(always)] + fn write_to

(&self, packet: &mut P) + where + P: TlPacket, + { + write_bytes(self, packet) + } +} + /// Helper type which is used to represent field value as bytes. #[derive(Debug, Clone)] pub struct IntermediateBytes(pub T); @@ -425,8 +533,9 @@ impl PartialEq for RawBytes<'_, R> { impl Copy for RawBytes<'_, R> {} impl Clone for RawBytes<'_, R> { + #[inline] fn clone(&self) -> Self { - Self(self.0, std::marker::PhantomData) + *self } } diff --git a/test_suite/tests/tl_read.rs b/test_suite/tests/tl_read.rs index 8615e6a..a425ee2 100644 --- a/test_suite/tests/tl_read.rs +++ b/test_suite/tests/tl_read.rs @@ -1,6 +1,6 @@ #[allow(dead_code)] mod tests { - use tl_proto::{BytesMeta, TlError, TlRead, TlResult}; + use tl_proto::{BoundedBytes, BytesMeta, TlError, TlRead, TlResult}; #[derive(TlRead)] struct SimpleStruct { @@ -166,4 +166,22 @@ mod tests { Err(TlError::UnexpectedEof) )); } + + #[test] + fn bounded_bytes() { + #[derive(TlRead)] + struct Data<'tl> { + bytes: &'tl BoundedBytes<4>, + } + + let packet = [4, 1, 2, 3, 4, 0, 0, 0]; + let Data { bytes } = tl_proto::deserialize(&packet).unwrap(); + assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]); + + let big_packet = [5, 1, 2, 3, 4, 5, 0, 0]; + assert!(matches!( + tl_proto::deserialize::(&big_packet), + Err(TlError::InvalidData) + )); + } }