From e9ffde4076f6554c0c159acbf13886de5a2154cc Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Aug 2024 03:04:49 -0700 Subject: [PATCH] fix(sqlite): audit for bad casts --- sqlx-sqlite/src/connection/collation.rs | 16 ++++- sqlx-sqlite/src/connection/explain.rs | 7 ++ sqlx-sqlite/src/connection/intmap.rs | 7 ++ sqlx-sqlite/src/logger.rs | 8 +++ sqlx-sqlite/src/statement/handle.rs | 74 +++++++++++++++------- sqlx-sqlite/src/statement/unlock_notify.rs | 3 +- sqlx-sqlite/src/statement/virtual.rs | 8 ++- sqlx-sqlite/src/types/chrono.rs | 16 ++++- sqlx-sqlite/src/types/float.rs | 2 + sqlx-sqlite/src/value.rs | 8 ++- 10 files changed, 116 insertions(+), 33 deletions(-) diff --git a/sqlx-sqlite/src/connection/collation.rs b/sqlx-sqlite/src/connection/collation.rs index 8cffda84c5..573a9af892 100644 --- a/sqlx-sqlite/src/connection/collation.rs +++ b/sqlx-sqlite/src/connection/collation.rs @@ -127,13 +127,23 @@ where C: Fn(&str, &str) -> Ordering, { let boxed_f: *mut C = data as *mut C; - debug_assert!(!boxed_f.is_null()); + + // Note: unwinding is now caught at the FFI boundary: + // https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-unwinding + assert!(!boxed_f.is_null()); + + let left_len = + usize::try_from(left_len).unwrap_or_else(|_| panic!("left_len out of range: {left_len}")); + + let right_len = usize::try_from(right_len) + .unwrap_or_else(|_| panic!("right_len out of range: {right_len}")); + let s1 = { - let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len as usize); + let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len); from_utf8_unchecked(c_slice) }; let s2 = { - let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len as usize); + let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len); from_utf8_unchecked(c_slice) }; let t = (*boxed_f)(s1, s2); diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index a18cd58a53..89762d171f 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -1,3 +1,10 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] use crate::connection::intmap::IntMap; use crate::connection::{execute, ConnectionState}; use crate::error::Error; diff --git a/sqlx-sqlite/src/connection/intmap.rs b/sqlx-sqlite/src/connection/intmap.rs index 05a27ba9d8..dc09162f64 100644 --- a/sqlx-sqlite/src/connection/intmap.rs +++ b/sqlx-sqlite/src/connection/intmap.rs @@ -1,3 +1,10 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] use std::cmp::Ordering; use std::{fmt::Debug, hash::Hash}; diff --git a/sqlx-sqlite/src/logger.rs b/sqlx-sqlite/src/logger.rs index a3de1374e3..40fabd48ed 100644 --- a/sqlx-sqlite/src/logger.rs +++ b/sqlx-sqlite/src/logger.rs @@ -1,3 +1,11 @@ +// Bad casts in this module SHOULD NOT result in a SQL injection +// https://github.com/launchbadge/sqlx/issues/3440 +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss +)] + use crate::connection::intmap::IntMap; use std::collections::HashSet; use std::fmt::Debug; diff --git a/sqlx-sqlite/src/statement/handle.rs b/sqlx-sqlite/src/statement/handle.rs index e1a7ab3de1..b1b72242d3 100644 --- a/sqlx-sqlite/src/statement/handle.rs +++ b/sqlx-sqlite/src/statement/handle.rs @@ -34,6 +34,21 @@ pub(crate) struct StatementHandle(NonNull); unsafe impl Send for StatementHandle {} +macro_rules! expect_ret_valid { + ($fn_name:ident($($args:tt)*)) => {{ + let val = $fn_name($($args)*); + + TryFrom::try_from(val) + .unwrap_or_else(|_| panic!("{}() returned invalid value: {val:?}", stringify!($fn_name))) + }} +} + +macro_rules! check_col_idx { + ($idx:ident) => { + c_int::try_from($idx).unwrap_or_else(|_| panic!("invalid column index: {}", $idx)) + }; +} + // might use some of this later #[allow(dead_code)] impl StatementHandle { @@ -71,7 +86,7 @@ impl StatementHandle { #[inline] pub(crate) fn column_count(&self) -> usize { // https://sqlite.org/c3ref/column_count.html - unsafe { sqlite3_column_count(self.0.as_ptr()) as usize } + unsafe { expect_ret_valid!(sqlite3_column_count(self.0.as_ptr())) } } #[inline] @@ -79,14 +94,14 @@ impl StatementHandle { // returns the number of changes of the *last* statement; not // necessarily this statement. // https://sqlite.org/c3ref/changes.html - unsafe { sqlite3_changes(self.db_handle()) as u64 } + unsafe { expect_ret_valid!(sqlite3_changes(self.db_handle())) } } #[inline] pub(crate) fn column_name(&self, index: usize) -> &str { // https://sqlite.org/c3ref/column_name.html unsafe { - let name = sqlite3_column_name(self.0.as_ptr(), index as c_int); + let name = sqlite3_column_name(self.0.as_ptr(), check_col_idx!(index)); debug_assert!(!name.is_null()); from_utf8_unchecked(CStr::from_ptr(name).to_bytes()) @@ -107,7 +122,7 @@ impl StatementHandle { #[inline] pub(crate) fn column_decltype(&self, index: usize) -> Option { unsafe { - let decl = sqlite3_column_decltype(self.0.as_ptr(), index as c_int); + let decl = sqlite3_column_decltype(self.0.as_ptr(), check_col_idx!(index)); if decl.is_null() { // If the Nth column of the result set is an expression or subquery, // then a NULL pointer is returned. @@ -123,6 +138,8 @@ impl StatementHandle { pub(crate) fn column_nullable(&self, index: usize) -> Result, Error> { unsafe { + let index = check_col_idx!(index); + // https://sqlite.org/c3ref/column_database_name.html // // ### Note @@ -130,9 +147,9 @@ impl StatementHandle { // sqlite3_finalize() or until the statement is automatically reprepared by the // first call to sqlite3_step() for a particular run or until the same information // is requested again in a different encoding. - let db_name = sqlite3_column_database_name(self.0.as_ptr(), index as c_int); - let table_name = sqlite3_column_table_name(self.0.as_ptr(), index as c_int); - let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index as c_int); + let db_name = sqlite3_column_database_name(self.0.as_ptr(), index); + let table_name = sqlite3_column_table_name(self.0.as_ptr(), index); + let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index); if db_name.is_null() || table_name.is_null() || origin_name.is_null() { return Ok(None); @@ -174,7 +191,7 @@ impl StatementHandle { #[inline] pub(crate) fn bind_parameter_count(&self) -> usize { // https://www.sqlite.org/c3ref/bind_parameter_count.html - unsafe { sqlite3_bind_parameter_count(self.0.as_ptr()) as usize } + unsafe { expect_ret_valid!(sqlite3_bind_parameter_count(self.0.as_ptr())) } } // Name Of A Host Parameter @@ -183,7 +200,7 @@ impl StatementHandle { pub(crate) fn bind_parameter_name(&self, index: usize) -> Option<&str> { unsafe { // https://www.sqlite.org/c3ref/bind_parameter_name.html - let name = sqlite3_bind_parameter_name(self.0.as_ptr(), index as c_int); + let name = sqlite3_bind_parameter_name(self.0.as_ptr(), check_col_idx!(index)); if name.is_null() { return None; } @@ -200,7 +217,7 @@ impl StatementHandle { unsafe { sqlite3_bind_blob64( self.0.as_ptr(), - index as c_int, + check_col_idx!(index), v.as_ptr() as *const c_void, v.len() as u64, SQLITE_TRANSIENT(), @@ -210,36 +227,39 @@ impl StatementHandle { #[inline] pub(crate) fn bind_text(&self, index: usize, v: &str) -> c_int { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let encoding = SQLITE_UTF8 as u8; + unsafe { sqlite3_bind_text64( self.0.as_ptr(), - index as c_int, + check_col_idx!(index), v.as_ptr() as *const c_char, v.len() as u64, SQLITE_TRANSIENT(), - SQLITE_UTF8 as u8, + encoding, ) } } #[inline] pub(crate) fn bind_int(&self, index: usize, v: i32) -> c_int { - unsafe { sqlite3_bind_int(self.0.as_ptr(), index as c_int, v as c_int) } + unsafe { sqlite3_bind_int(self.0.as_ptr(), check_col_idx!(index), v as c_int) } } #[inline] pub(crate) fn bind_int64(&self, index: usize, v: i64) -> c_int { - unsafe { sqlite3_bind_int64(self.0.as_ptr(), index as c_int, v) } + unsafe { sqlite3_bind_int64(self.0.as_ptr(), check_col_idx!(index), v) } } #[inline] pub(crate) fn bind_double(&self, index: usize, v: f64) -> c_int { - unsafe { sqlite3_bind_double(self.0.as_ptr(), index as c_int, v) } + unsafe { sqlite3_bind_double(self.0.as_ptr(), check_col_idx!(index), v) } } #[inline] pub(crate) fn bind_null(&self, index: usize) -> c_int { - unsafe { sqlite3_bind_null(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_bind_null(self.0.as_ptr(), check_col_idx!(index)) } } // result values from the query @@ -247,39 +267,45 @@ impl StatementHandle { #[inline] pub(crate) fn column_type(&self, index: usize) -> c_int { - unsafe { sqlite3_column_type(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_type(self.0.as_ptr(), check_col_idx!(index)) } } #[inline] pub(crate) fn column_int(&self, index: usize) -> i32 { - unsafe { sqlite3_column_int(self.0.as_ptr(), index as c_int) as i32 } + unsafe { sqlite3_column_int(self.0.as_ptr(), check_col_idx!(index)) as i32 } } #[inline] pub(crate) fn column_int64(&self, index: usize) -> i64 { - unsafe { sqlite3_column_int64(self.0.as_ptr(), index as c_int) as i64 } + unsafe { sqlite3_column_int64(self.0.as_ptr(), check_col_idx!(index)) as i64 } } #[inline] pub(crate) fn column_double(&self, index: usize) -> f64 { - unsafe { sqlite3_column_double(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_double(self.0.as_ptr(), check_col_idx!(index)) } } #[inline] pub(crate) fn column_value(&self, index: usize) -> *mut sqlite3_value { - unsafe { sqlite3_column_value(self.0.as_ptr(), index as c_int) } + unsafe { sqlite3_column_value(self.0.as_ptr(), check_col_idx!(index)) } } pub(crate) fn column_blob(&self, index: usize) -> &[u8] { - let index = index as c_int; - let len = unsafe { sqlite3_column_bytes(self.0.as_ptr(), index) } as usize; + let len = unsafe { sqlite3_column_bytes(self.0.as_ptr(), check_col_idx!(index)) }; + + // This likely means UB in SQLite itself or our usage of it; + // signed integer overflow is UB in the C standard. + let len = usize::try_from(len).unwrap_or_else(|_| { + panic!("sqlite3_value_bytes() returned value out of range for usize: {len}") + }); if len == 0 { // empty blobs are NULL so just return an empty slice return &[]; } - let ptr = unsafe { sqlite3_column_blob(self.0.as_ptr(), index) } as *const u8; + let ptr = + unsafe { sqlite3_column_blob(self.0.as_ptr(), check_col_idx!(index)) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } diff --git a/sqlx-sqlite/src/statement/unlock_notify.rs b/sqlx-sqlite/src/statement/unlock_notify.rs index b7e723a3f3..5821c23ae3 100644 --- a/sqlx-sqlite/src/statement/unlock_notify.rs +++ b/sqlx-sqlite/src/statement/unlock_notify.rs @@ -27,7 +27,8 @@ pub unsafe fn wait(conn: *mut sqlite3) -> Result<(), SqliteError> { unsafe extern "C" fn unlock_notify_cb(ptr: *mut *mut c_void, len: c_int) { let ptr = ptr as *mut &Notify; - let slice = slice::from_raw_parts(ptr, len as usize); + // We don't have a choice; we can't panic and unwind into FFI here. + let slice = slice::from_raw_parts(ptr, usize::try_from(len).unwrap_or(0)); for notify in slice { notify.fire(); diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 3c17428912..6be980c36a 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -163,7 +163,13 @@ fn prepare( let mut tail: *const c_char = null(); let query_ptr = query.as_ptr() as *const c_char; - let query_len = query.len() as i32; + let query_len = i32::try_from(query.len()).map_err(|_| { + err_protocol!( + "query string too large for SQLite3 API ({} bytes); \ + try breaking it into smaller chunks (< 2 GiB), executed separately", + query.len() + ) + })?; // let status = unsafe { diff --git a/sqlx-sqlite/src/types/chrono.rs b/sqlx-sqlite/src/types/chrono.rs index c491a9aa66..7424720444 100644 --- a/sqlx-sqlite/src/types/chrono.rs +++ b/sqlx-sqlite/src/types/chrono.rs @@ -167,10 +167,20 @@ fn decode_datetime_from_float(value: f64) -> Option> { let epoch_in_julian_days = 2_440_587.5; let seconds_in_day = 86400.0; let timestamp = (value - epoch_in_julian_days) * seconds_in_day; - let seconds = timestamp as i64; - let nanos = (timestamp.fract() * 1E9) as u32; - Utc.fix().timestamp_opt(seconds, nanos).single() + if !timestamp.is_finite() { + return None; + } + + // We don't really have a choice but to do lossy casts for this conversion + // We checked above if the value is infinite or NaN which could otherwise cause problems + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + { + let seconds = timestamp.trunc() as i64; + let nanos = (timestamp.fract() * 1E9).abs() as u32; + + Utc.fix().timestamp_opt(seconds, nanos).single() + } } impl<'r> Decode<'r, Sqlite> for NaiveDateTime { diff --git a/sqlx-sqlite/src/types/float.rs b/sqlx-sqlite/src/types/float.rs index 499a694242..79224f5451 100644 --- a/sqlx-sqlite/src/types/float.rs +++ b/sqlx-sqlite/src/types/float.rs @@ -24,6 +24,8 @@ impl<'q> Encode<'q, Sqlite> for f32 { impl<'r> Decode<'r, Sqlite> for f32 { fn decode(value: SqliteValueRef<'r>) -> Result { + // Truncation is intentional + #[allow(clippy::cast_possible_truncation)] Ok(value.double() as f32) } } diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 1a4d8898a4..967b3f7476 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -120,7 +120,13 @@ impl SqliteValue { } fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) } as usize; + let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) }; + + // This likely means UB in SQLite itself or our usage of it; + // signed integer overflow is UB in the C standard. + let len = usize::try_from(len).unwrap_or_else(|_| { + panic!("sqlite3_value_bytes() returned value out of range for usize: {len}") + }); if len == 0 { // empty blobs are NULL so just return an empty slice