From 8386651feec5c4292c14defd22f128816f7f0e49 Mon Sep 17 00:00:00 2001 From: Ian McCormack Date: Mon, 8 Jan 2024 09:57:28 -0500 Subject: [PATCH] Switched to storing mz_stream as a raw pointer to fix tree borrows violation. Removed Deref and DerefMut implementations for StreamWrapper. --- src/ffi/c.rs | 173 ++++++++++++++++++++++++++++----------------------- src/mem.rs | 30 ++++++--- 2 files changed, 114 insertions(+), 89 deletions(-) diff --git a/src/ffi/c.rs b/src/ffi/c.rs index 01e16ac8..d12e33fd 100644 --- a/src/ffi/c.rs +++ b/src/ffi/c.rs @@ -4,7 +4,6 @@ use std::cmp; use std::convert::TryFrom; use std::fmt; use std::marker; -use std::ops::{Deref, DerefMut}; use std::os::raw::{c_int, c_uint, c_void}; use std::ptr; @@ -21,7 +20,10 @@ impl ErrorMessage { } pub struct StreamWrapper { - pub inner: Box<mz_stream>, + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure, and it must never be copied + // by Rust. + pub inner: *mut mz_stream, } impl fmt::Debug for StreamWrapper { @@ -32,8 +34,12 @@ impl fmt::Debug for StreamWrapper { impl Default for StreamWrapper { fn default() -> StreamWrapper { + // SAFETY: The field `state` will be initialized across the FFI to + // point to the opaque type `mz_internal_state`, which will contain a copy + // of `inner`. This cyclic structure breaks the uniqueness invariant of + // &mut mz_stream, so we must use a raw pointer instead of Box. 
StreamWrapper { - inner: Box::new(mz_stream { + inner: Box::into_raw(Box::new(mz_stream { next_in: ptr::null_mut(), avail_in: 0, total_in: 0, @@ -54,11 +60,21 @@ impl Default for StreamWrapper { zalloc: Some(zalloc), #[cfg(not(all(feature = "any_zlib", not(feature = "cloudflare-zlib-sys"))))] zfree: Some(zfree), - }), + })), } } } +impl Drop for StreamWrapper { + fn drop(&mut self) { + // SAFETY: At this point, every other allocation for the struct has been freed by + // `inflateEnd` or `deflateEnd`, and no copies of `inner` are retained by C, + // so it is safe to drop the struct as long as the user respects the invariant that + // `inner` must never be copied by Rust. + drop(unsafe { Box::from_raw(self.inner) }); + } +} + const ALIGN: usize = std::mem::align_of::(); fn align_up(size: usize, align: usize) -> usize { @@ -110,20 +126,6 @@ extern "C" fn zfree(_ptr: *mut c_void, address: *mut c_void) { } } -impl Deref for StreamWrapper { - type Target = mz_stream; - - fn deref(&self) -> &Self::Target { - &*self.inner - } -} - -impl DerefMut for StreamWrapper { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut *self.inner - } -} - unsafe impl Send for Stream {} unsafe impl Sync for Stream {} @@ -148,7 +150,10 @@ pub struct Stream { impl Stream { pub fn msg(&self) -> ErrorMessage { - let msg = self.stream_wrapper.msg; + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self`. 
+ let msg = unsafe { (*self.stream_wrapper.inner).msg }; ErrorMessage(if msg.is_null() { None } else { @@ -161,7 +166,7 @@ impl Stream { impl Drop for Stream { fn drop(&mut self) { unsafe { - let _ = D::destroy(&mut *self.stream_wrapper); + let _ = D::destroy(self.stream_wrapper.inner); } } } @@ -185,9 +190,9 @@ pub struct Inflate { impl InflateBackend for Inflate { fn make(zlib_header: bool, window_bits: u8) -> Self { unsafe { - let mut state = StreamWrapper::default(); + let state = StreamWrapper::default(); let ret = mz_inflateInit2( - &mut *state, + state.inner, if zlib_header { window_bits as c_int } else { @@ -212,33 +217,38 @@ impl InflateBackend for Inflate { output: &mut [u8], flush: FlushDecompress, ) -> Result { - let raw = &mut *self.inner.stream_wrapper; - raw.msg = ptr::null_mut(); - raw.next_in = input.as_ptr() as *mut u8; - raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint; - raw.next_out = output.as_mut_ptr(); - raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint; - - let rc = unsafe { mz_inflate(raw, flush as c_int) }; - - // Unfortunately the total counters provided by zlib might be only - // 32 bits wide and overflow while processing large amounts of data. 
- self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64; - self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64; - - // reset these pointers so we don't accidentally read them later - raw.next_in = ptr::null_mut(); - raw.avail_in = 0; - raw.next_out = ptr::null_mut(); - raw.avail_out = 0; - - match rc { - MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()), - MZ_OK => Ok(Status::Ok), - MZ_BUF_ERROR => Ok(Status::BufError), - MZ_STREAM_END => Ok(Status::StreamEnd), - MZ_NEED_DICT => mem::decompress_need_dict(raw.adler as u32), - c => panic!("unknown return code: {}", c), + let raw = self.inner.stream_wrapper.inner; + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self`. + unsafe { + (*raw).msg = ptr::null_mut(); + (*raw).next_in = input.as_ptr() as *mut u8; + (*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint; + (*raw).next_out = output.as_mut_ptr(); + (*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint; + + let rc = mz_inflate(raw, flush as c_int); + + // Unfortunately the total counters provided by zlib might be only + // 32 bits wide and overflow while processing large amounts of data. 
+ self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64; + self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64; + + // reset these pointers so we don't accidentally read them later + (*raw).next_in = ptr::null_mut(); + (*raw).avail_in = 0; + (*raw).next_out = ptr::null_mut(); + (*raw).avail_out = 0; + + match rc { + MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()), + MZ_OK => Ok(Status::Ok), + MZ_BUF_ERROR => Ok(Status::BufError), + MZ_STREAM_END => Ok(Status::StreamEnd), + MZ_NEED_DICT => mem::decompress_need_dict((*raw).adler as u32), + c => panic!("unknown return code: {}", c), + } } } @@ -249,7 +259,7 @@ impl InflateBackend for Inflate { -MZ_DEFAULT_WINDOW_BITS }; unsafe { - inflateReset2(&mut *self.inner.stream_wrapper, bits); + inflateReset2(self.inner.stream_wrapper.inner, bits); } self.inner.total_out = 0; self.inner.total_in = 0; @@ -276,9 +286,9 @@ pub struct Deflate { impl DeflateBackend for Deflate { fn make(level: Compression, zlib_header: bool, window_bits: u8) -> Self { unsafe { - let mut state = StreamWrapper::default(); + let state = StreamWrapper::default(); let ret = mz_deflateInit2( - &mut *state, + state.inner, level.0 as c_int, MZ_DEFLATED, if zlib_header { @@ -306,39 +316,44 @@ impl DeflateBackend for Deflate { output: &mut [u8], flush: FlushCompress, ) -> Result { - let raw = &mut *self.inner.stream_wrapper; - raw.msg = ptr::null_mut(); - raw.next_in = input.as_ptr() as *mut _; - raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint; - raw.next_out = output.as_mut_ptr(); - raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint; - - let rc = unsafe { mz_deflate(raw, flush as c_int) }; - - // Unfortunately the total counters provided by zlib might be only - // 32 bits wide and overflow while processing large amounts of data. 
- self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64; - self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64; - - // reset these pointers so we don't accidentally read them later - raw.next_in = ptr::null_mut(); - raw.avail_in = 0; - raw.next_out = ptr::null_mut(); - raw.avail_out = 0; - - match rc { - MZ_OK => Ok(Status::Ok), - MZ_BUF_ERROR => Ok(Status::BufError), - MZ_STREAM_END => Ok(Status::StreamEnd), - MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()), - c => panic!("unknown return code: {}", c), + let raw = self.inner.stream_wrapper.inner; + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self`. + unsafe { + (*raw).msg = ptr::null_mut(); + (*raw).next_in = input.as_ptr() as *mut _; + (*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint; + (*raw).next_out = output.as_mut_ptr(); + (*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint; + + let rc = mz_deflate(raw, flush as c_int); + + // Unfortunately the total counters provided by zlib might be only + // 32 bits wide and overflow while processing large amounts of data. 
+ + self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64; + self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64; + // reset these pointers so we don't accidentally read them later + (*raw).next_in = ptr::null_mut(); + (*raw).avail_in = 0; + (*raw).next_out = ptr::null_mut(); + (*raw).avail_out = 0; + + match rc { + MZ_OK => Ok(Status::Ok), + MZ_BUF_ERROR => Ok(Status::BufError), + MZ_STREAM_END => Ok(Status::StreamEnd), + MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()), + c => panic!("unknown return code: {}", c), + } } } fn reset(&mut self) { self.inner.total_in = 0; self.inner.total_out = 0; - let rc = unsafe { mz_deflateReset(&mut *self.inner.stream_wrapper) }; + let rc = unsafe { mz_deflateReset(self.inner.stream_wrapper.inner) }; assert_eq!(rc, MZ_OK); } } diff --git a/src/mem.rs b/src/mem.rs index d4a50917..86fa8d3b 100644 --- a/src/mem.rs +++ b/src/mem.rs @@ -265,16 +265,19 @@ impl Compress { /// Returns the Adler-32 checksum of the dictionary. #[cfg(feature = "any_zlib")] pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result { - let stream = &mut *self.inner.inner.stream_wrapper; - stream.msg = std::ptr::null_mut(); + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self.inner.inner.stream_wrapper`. 
+ let stream = self.inner.inner.stream_wrapper.inner; let rc = unsafe { + (*stream).msg = std::ptr::null_mut(); assert!(dictionary.len() < ffi::uInt::MAX as usize); ffi::deflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt) }; match rc { ffi::MZ_STREAM_ERROR => compress_failed(self.inner.inner.msg()), - ffi::MZ_OK => Ok(stream.adler as u32), + ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32), c => panic!("unknown return code: {}", c), } } @@ -299,9 +302,13 @@ impl Compress { #[cfg(feature = "any_zlib")] pub fn set_level(&mut self, level: Compression) -> Result<(), CompressError> { use std::os::raw::c_int; - let stream = &mut *self.inner.inner.stream_wrapper; - stream.msg = std::ptr::null_mut(); - + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self.inner.inner.stream_wrapper`. + let stream = self.inner.inner.stream_wrapper.inner; + unsafe { + (*stream).msg = std::ptr::null_mut(); + } let rc = unsafe { ffi::deflateParams(stream, level.0 as c_int, ffi::MZ_DEFAULT_STRATEGY) }; match rc { @@ -476,17 +483,20 @@ impl Decompress { /// Specifies the decompression dictionary to use. #[cfg(feature = "any_zlib")] pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result { - let stream = &mut *self.inner.inner.stream_wrapper; - stream.msg = std::ptr::null_mut(); + // SAFETY: The field `inner` must always be accessed as a raw pointer, + // since it points to a cyclic structure. No copies of `inner` can be + // retained for longer than the lifetime of `self.inner.inner.stream_wrapper`. 
+ let stream = self.inner.inner.stream_wrapper.inner; let rc = unsafe { + (*stream).msg = std::ptr::null_mut(); assert!(dictionary.len() < ffi::uInt::MAX as usize); ffi::inflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt) }; match rc { ffi::MZ_STREAM_ERROR => decompress_failed(self.inner.inner.msg()), - ffi::MZ_DATA_ERROR => decompress_need_dict(stream.adler as u32), - ffi::MZ_OK => Ok(stream.adler as u32), + ffi::MZ_DATA_ERROR => decompress_need_dict(unsafe { (*stream).adler } as u32), + ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32), c => panic!("unknown return code: {}", c), } }