Skip to content

Commit

Permalink
fix(sqlite): audit for bad casts
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Aug 20, 2024
1 parent 3dfc305 commit e9ffde4
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 33 deletions.
16 changes: 13 additions & 3 deletions sqlx-sqlite/src/connection/collation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,23 @@ where
C: Fn(&str, &str) -> Ordering,
{
let boxed_f: *mut C = data as *mut C;
debug_assert!(!boxed_f.is_null());

// Note: unwinding is now caught at the FFI boundary:
// https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-unwinding
assert!(!boxed_f.is_null());

let left_len =
usize::try_from(left_len).unwrap_or_else(|_| panic!("left_len out of range: {left_len}"));

let right_len = usize::try_from(right_len)
.unwrap_or_else(|_| panic!("right_len out of range: {right_len}"));

let s1 = {
let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len as usize);
let c_slice = slice::from_raw_parts(left_ptr as *const u8, left_len);
from_utf8_unchecked(c_slice)
};
let s2 = {
let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len as usize);
let c_slice = slice::from_raw_parts(right_ptr as *const u8, right_len);
from_utf8_unchecked(c_slice)
};
let t = (*boxed_f)(s1, s2);
Expand Down
7 changes: 7 additions & 0 deletions sqlx-sqlite/src/connection/explain.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
// Bad casts in this module SHOULD NOT result in a SQL injection
// https://github.com/launchbadge/sqlx/issues/3440
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]
use crate::connection::intmap::IntMap;
use crate::connection::{execute, ConnectionState};
use crate::error::Error;
Expand Down
7 changes: 7 additions & 0 deletions sqlx-sqlite/src/connection/intmap.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
// Bad casts in this module SHOULD NOT result in a SQL injection
// https://github.com/launchbadge/sqlx/issues/3440
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]
use std::cmp::Ordering;
use std::{fmt::Debug, hash::Hash};

Expand Down
8 changes: 8 additions & 0 deletions sqlx-sqlite/src/logger.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// Bad casts in this module SHOULD NOT result in a SQL injection
// https://github.com/launchbadge/sqlx/issues/3440
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]

use crate::connection::intmap::IntMap;
use std::collections::HashSet;
use std::fmt::Debug;
Expand Down
74 changes: 50 additions & 24 deletions sqlx-sqlite/src/statement/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);

unsafe impl Send for StatementHandle {}

macro_rules! expect_ret_valid {
($fn_name:ident($($args:tt)*)) => {{
let val = $fn_name($($args)*);

TryFrom::try_from(val)
.unwrap_or_else(|_| panic!("{}() returned invalid value: {val:?}", stringify!($fn_name)))
}}
}

macro_rules! check_col_idx {
($idx:ident) => {
c_int::try_from($idx).unwrap_or_else(|_| panic!("invalid column index: {}", $idx))
};
}

// might use some of this later
#[allow(dead_code)]
impl StatementHandle {
Expand Down Expand Up @@ -71,22 +86,22 @@ impl StatementHandle {
#[inline]
pub(crate) fn column_count(&self) -> usize {
// https://sqlite.org/c3ref/column_count.html
unsafe { sqlite3_column_count(self.0.as_ptr()) as usize }
unsafe { expect_ret_valid!(sqlite3_column_count(self.0.as_ptr())) }
}

#[inline]
pub(crate) fn changes(&self) -> u64 {
// returns the number of changes of the *last* statement; not
// necessarily this statement.
// https://sqlite.org/c3ref/changes.html
unsafe { sqlite3_changes(self.db_handle()) as u64 }
unsafe { expect_ret_valid!(sqlite3_changes(self.db_handle())) }
}

#[inline]
pub(crate) fn column_name(&self, index: usize) -> &str {
// https://sqlite.org/c3ref/column_name.html
unsafe {
let name = sqlite3_column_name(self.0.as_ptr(), index as c_int);
let name = sqlite3_column_name(self.0.as_ptr(), check_col_idx!(index));
debug_assert!(!name.is_null());

from_utf8_unchecked(CStr::from_ptr(name).to_bytes())
Expand All @@ -107,7 +122,7 @@ impl StatementHandle {
#[inline]
pub(crate) fn column_decltype(&self, index: usize) -> Option<SqliteTypeInfo> {
unsafe {
let decl = sqlite3_column_decltype(self.0.as_ptr(), index as c_int);
let decl = sqlite3_column_decltype(self.0.as_ptr(), check_col_idx!(index));
if decl.is_null() {
// If the Nth column of the result set is an expression or subquery,
// then a NULL pointer is returned.
Expand All @@ -123,16 +138,18 @@ impl StatementHandle {

pub(crate) fn column_nullable(&self, index: usize) -> Result<Option<bool>, Error> {
unsafe {
let index = check_col_idx!(index);

// https://sqlite.org/c3ref/column_database_name.html
//
// ### Note
// The returned string is valid until the prepared statement is destroyed using
// sqlite3_finalize() or until the statement is automatically reprepared by the
// first call to sqlite3_step() for a particular run or until the same information
// is requested again in a different encoding.
let db_name = sqlite3_column_database_name(self.0.as_ptr(), index as c_int);
let table_name = sqlite3_column_table_name(self.0.as_ptr(), index as c_int);
let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index as c_int);
let db_name = sqlite3_column_database_name(self.0.as_ptr(), index);
let table_name = sqlite3_column_table_name(self.0.as_ptr(), index);
let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), index);

if db_name.is_null() || table_name.is_null() || origin_name.is_null() {
return Ok(None);
Expand Down Expand Up @@ -174,7 +191,7 @@ impl StatementHandle {
#[inline]
pub(crate) fn bind_parameter_count(&self) -> usize {
// https://www.sqlite.org/c3ref/bind_parameter_count.html
unsafe { sqlite3_bind_parameter_count(self.0.as_ptr()) as usize }
unsafe { expect_ret_valid!(sqlite3_bind_parameter_count(self.0.as_ptr())) }
}

// Name Of A Host Parameter
Expand All @@ -183,7 +200,7 @@ impl StatementHandle {
pub(crate) fn bind_parameter_name(&self, index: usize) -> Option<&str> {
unsafe {
// https://www.sqlite.org/c3ref/bind_parameter_name.html
let name = sqlite3_bind_parameter_name(self.0.as_ptr(), index as c_int);
let name = sqlite3_bind_parameter_name(self.0.as_ptr(), check_col_idx!(index));
if name.is_null() {
return None;
}
Expand All @@ -200,7 +217,7 @@ impl StatementHandle {
unsafe {
sqlite3_bind_blob64(
self.0.as_ptr(),
index as c_int,
check_col_idx!(index),
v.as_ptr() as *const c_void,
v.len() as u64,
SQLITE_TRANSIENT(),
Expand All @@ -210,76 +227,85 @@ impl StatementHandle {

#[inline]
pub(crate) fn bind_text(&self, index: usize, v: &str) -> c_int {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let encoding = SQLITE_UTF8 as u8;

unsafe {
sqlite3_bind_text64(
self.0.as_ptr(),
index as c_int,
check_col_idx!(index),
v.as_ptr() as *const c_char,
v.len() as u64,
SQLITE_TRANSIENT(),
SQLITE_UTF8 as u8,
encoding,
)
}
}

#[inline]
pub(crate) fn bind_int(&self, index: usize, v: i32) -> c_int {
unsafe { sqlite3_bind_int(self.0.as_ptr(), index as c_int, v as c_int) }
unsafe { sqlite3_bind_int(self.0.as_ptr(), check_col_idx!(index), v as c_int) }
}

#[inline]
pub(crate) fn bind_int64(&self, index: usize, v: i64) -> c_int {
unsafe { sqlite3_bind_int64(self.0.as_ptr(), index as c_int, v) }
unsafe { sqlite3_bind_int64(self.0.as_ptr(), check_col_idx!(index), v) }
}

#[inline]
pub(crate) fn bind_double(&self, index: usize, v: f64) -> c_int {
unsafe { sqlite3_bind_double(self.0.as_ptr(), index as c_int, v) }
unsafe { sqlite3_bind_double(self.0.as_ptr(), check_col_idx!(index), v) }
}

#[inline]
pub(crate) fn bind_null(&self, index: usize) -> c_int {
unsafe { sqlite3_bind_null(self.0.as_ptr(), index as c_int) }
unsafe { sqlite3_bind_null(self.0.as_ptr(), check_col_idx!(index)) }
}

// result values from the query
// https://www.sqlite.org/c3ref/column_blob.html

#[inline]
pub(crate) fn column_type(&self, index: usize) -> c_int {
unsafe { sqlite3_column_type(self.0.as_ptr(), index as c_int) }
unsafe { sqlite3_column_type(self.0.as_ptr(), check_col_idx!(index)) }
}

#[inline]
pub(crate) fn column_int(&self, index: usize) -> i32 {
unsafe { sqlite3_column_int(self.0.as_ptr(), index as c_int) as i32 }
unsafe { sqlite3_column_int(self.0.as_ptr(), check_col_idx!(index)) as i32 }
}

#[inline]
pub(crate) fn column_int64(&self, index: usize) -> i64 {
unsafe { sqlite3_column_int64(self.0.as_ptr(), index as c_int) as i64 }
unsafe { sqlite3_column_int64(self.0.as_ptr(), check_col_idx!(index)) as i64 }
}

#[inline]
pub(crate) fn column_double(&self, index: usize) -> f64 {
unsafe { sqlite3_column_double(self.0.as_ptr(), index as c_int) }
unsafe { sqlite3_column_double(self.0.as_ptr(), check_col_idx!(index)) }
}

#[inline]
pub(crate) fn column_value(&self, index: usize) -> *mut sqlite3_value {
unsafe { sqlite3_column_value(self.0.as_ptr(), index as c_int) }
unsafe { sqlite3_column_value(self.0.as_ptr(), check_col_idx!(index)) }
}

pub(crate) fn column_blob(&self, index: usize) -> &[u8] {
let index = index as c_int;
let len = unsafe { sqlite3_column_bytes(self.0.as_ptr(), index) } as usize;
let len = unsafe { sqlite3_column_bytes(self.0.as_ptr(), check_col_idx!(index)) };

// This likely means UB in SQLite itself or our usage of it;
// signed integer overflow is UB in the C standard.
let len = usize::try_from(len).unwrap_or_else(|_| {
panic!("sqlite3_value_bytes() returned value out of range for usize: {len}")
});

if len == 0 {
// empty blobs are NULL so just return an empty slice
return &[];
}

let ptr = unsafe { sqlite3_column_blob(self.0.as_ptr(), index) } as *const u8;
let ptr =
unsafe { sqlite3_column_blob(self.0.as_ptr(), check_col_idx!(index)) } as *const u8;
debug_assert!(!ptr.is_null());

unsafe { from_raw_parts(ptr, len) }
Expand Down
3 changes: 2 additions & 1 deletion sqlx-sqlite/src/statement/unlock_notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ pub unsafe fn wait(conn: *mut sqlite3) -> Result<(), SqliteError> {

unsafe extern "C" fn unlock_notify_cb(ptr: *mut *mut c_void, len: c_int) {
let ptr = ptr as *mut &Notify;
let slice = slice::from_raw_parts(ptr, len as usize);
// We don't have a choice; we can't panic and unwind into FFI here.
let slice = slice::from_raw_parts(ptr, usize::try_from(len).unwrap_or(0));

for notify in slice {
notify.fire();
Expand Down
8 changes: 7 additions & 1 deletion sqlx-sqlite/src/statement/virtual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ fn prepare(
let mut tail: *const c_char = null();

let query_ptr = query.as_ptr() as *const c_char;
let query_len = query.len() as i32;
let query_len = i32::try_from(query.len()).map_err(|_| {
err_protocol!(
"query string too large for SQLite3 API ({} bytes); \
try breaking it into smaller chunks (< 2 GiB), executed separately",
query.len()
)
})?;

// <https://www.sqlite.org/c3ref/prepare.html>
let status = unsafe {
Expand Down
16 changes: 13 additions & 3 deletions sqlx-sqlite/src/types/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,20 @@ fn decode_datetime_from_float(value: f64) -> Option<DateTime<FixedOffset>> {
let epoch_in_julian_days = 2_440_587.5;
let seconds_in_day = 86400.0;
let timestamp = (value - epoch_in_julian_days) * seconds_in_day;
let seconds = timestamp as i64;
let nanos = (timestamp.fract() * 1E9) as u32;

Utc.fix().timestamp_opt(seconds, nanos).single()
if !timestamp.is_finite() {
return None;
}

// We don't really have a choice but to do lossy casts for this conversion
// We checked above if the value is infinite or NaN which could otherwise cause problems
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
{
let seconds = timestamp.trunc() as i64;
let nanos = (timestamp.fract() * 1E9).abs() as u32;

Utc.fix().timestamp_opt(seconds, nanos).single()
}
}

impl<'r> Decode<'r, Sqlite> for NaiveDateTime {
Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/src/types/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ impl<'q> Encode<'q, Sqlite> for f32 {

impl<'r> Decode<'r, Sqlite> for f32 {
fn decode(value: SqliteValueRef<'r>) -> Result<f32, BoxDynError> {
// Truncation is intentional
#[allow(clippy::cast_possible_truncation)]
Ok(value.double() as f32)
}
}
Expand Down
8 changes: 7 additions & 1 deletion sqlx-sqlite/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ impl SqliteValue {
}

fn blob(&self) -> &[u8] {
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) } as usize;
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) };

// This likely means UB in SQLite itself or our usage of it;
// signed integer overflow is UB in the C standard.
let len = usize::try_from(len).unwrap_or_else(|_| {
panic!("sqlite3_value_bytes() returned value out of range for usize: {len}")
});

if len == 0 {
// empty blobs are NULL so just return an empty slice
Expand Down

0 comments on commit e9ffde4

Please sign in to comment.