diff --git a/sqlx-mysql/src/protocol/statement/row.rs b/sqlx-mysql/src/protocol/statement/row.rs index bf518f491f..dfbc4a1e96 100644 --- a/sqlx-mysql/src/protocol/statement/row.rs +++ b/sqlx-mysql/src/protocol/statement/row.rs @@ -45,7 +45,20 @@ impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow { // NOTE: MySQL will never generate NULL types for non-NULL values let type_info = &column.type_info; + // Unlike Postgres, MySQL does not length-prefix every value in a binary row. + // Values are *either* fixed-length or length-prefixed, + // so we need to inspect the type code to be sure. let size: usize = match type_info.r#type { + // All fixed-length types. + ColumnType::LongLong => 8, + ColumnType::Long | ColumnType::Int24 => 4, + ColumnType::Short | ColumnType::Year => 2, + ColumnType::Tiny => 1, + ColumnType::Float => 4, + ColumnType::Double => 8, + + // Blobs and strings are prefixed with their length, + // which is itself a length-encoded integer. ColumnType::String | ColumnType::VarChar | ColumnType::VarString @@ -61,20 +74,15 @@ impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow { | ColumnType::Json | ColumnType::NewDecimal => buf.get_uint_lenenc() as usize, - ColumnType::LongLong => 8, - ColumnType::Long | ColumnType::Int24 => 4, - ColumnType::Short | ColumnType::Year => 2, - ColumnType::Tiny => 1, - ColumnType::Float => 4, - ColumnType::Double => 8, - + // Like strings and blobs, these values are variable-length. + // Unlike strings and blobs, however, they exclusively use one byte for length. ColumnType::Time | ColumnType::Timestamp | ColumnType::Date | ColumnType::Datetime => { - // The size of this type is important for decoding + // Leave the length byte on the front of the value because decoding uses it. buf[0] as usize + 1 - } + }, // NOTE: MySQL will never generate NULL types for non-NULL values ColumnType::Null => unreachable!(), diff --git a/sqlx-mysql/src/types/chrono.rs b/sqlx-mysql/src/types/chrono.rs index 7da10434ab..dbbdc20864 100644 --- a/sqlx-mysql/src/types/chrono.rs +++ b/sqlx-mysql/src/types/chrono.rs @@ -1,14 +1,13 @@ use bytes::Buf; -use chrono::{ - DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc, -}; +use chrono::{DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; +use sqlx_core::database::Database; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::protocol::text::ColumnType; use crate::type_info::MySqlTypeInfo; -use crate::types::Type; +use crate::types::{MySqlTime, MySqlTimeSign, Type}; use crate::{MySql, MySqlValueFormat, MySqlValueRef}; impl Type for DateTime { @@ -63,7 +62,7 @@ impl<'r> Decode<'r, MySql> for DateTime { impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::binary(ColumnType::Time) + MySqlTime::type_info() } } @@ -75,7 +74,7 @@ impl Encode<'_, MySql> for NaiveTime { // NaiveTime is not negative buf.push(0); - // "date on 4 bytes little-endian format" (?) + // Number of days in the interval; always 0 for time-of-day values. // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); @@ -95,34 +94,18 @@ impl Encode<'_, MySql> for NaiveTime { } } +/// Decode from a `TIME` value. +/// +/// ### Errors +/// Returns an error if the `TIME` value is negative or exceeds `23:59:59.999999`. impl<'r> Decode<'r, MySql> for NaiveTime { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { - let mut buf = value.as_bytes()?; - - // data length, expecting 8 or 12 (fractional seconds) - let len = buf.get_u8(); - - // MySQL specifies that if all of hours, minutes, seconds, microseconds - // are 0 then the length is 0 and no further data is send - // https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - if len == 0 { - return Ok(NaiveTime::from_hms_micro_opt(0, 0, 0, 0) - .expect("expected NaiveTime to construct from all zeroes")); - } - - // is negative : int<1> - let is_negative = buf.get_u8(); - debug_assert_eq!(is_negative, 0, "Negative dates/times are not supported"); - - // "date on 4 bytes little-endian format" (?) - // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding - buf.advance(4); - - decode_time(len - 5, buf) + // Covers most possible failure modes. + MySqlTime::decode(value)?.try_into() } - + // Retaining this parsing for now as it allows us to cross-check our impl. MySqlValueFormat::Text => { let s = value.as_str()?; NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into) @@ -131,6 +114,55 @@ impl<'r> Decode<'r, MySql> for NaiveTime { } } +impl TryFrom for NaiveTime { + type Error = BoxDynError; + + fn try_from(time: MySqlTime) -> Result { + NaiveTime::from_hms_micro_opt( + time.hours(), + time.minutes() as u32, + time.seconds() as u32, + time.microseconds(), + ) + .ok_or_else(|| format!("Cannot convert `MySqlTime` value to `NaiveTime`: {time}").into()) + } +} + +impl From for chrono::TimeDelta { + fn from(time: MySqlTime) -> Self { + chrono::TimeDelta::new(time.whole_seconds_signed(), time.subsec_nanos()) + .expect("BUG: chrono::TimeDelta should have a greater range than MySqlTime") + } +} + +impl TryFrom for MySqlTime { + type Error = BoxDynError; + + fn try_from(value: chrono::TimeDelta) -> Result { + let sign = if value < chrono::TimeDelta::zero() { MySqlTimeSign::Negative } else { + MySqlTimeSign::Positive + }; + + Ok( + // `std::time::Duration` has a greater positive range than `TimeDelta` + // which makes it a great intermediate if you ignore the sign. + MySqlTime::try_from(value.abs().to_std()?)?.with_sign(sign) + ) + } +} + +impl Type for chrono::TimeDelta { + fn type_info() -> MySqlTypeInfo { + MySqlTime::type_info() + } +} + +impl<'r> Decode<'r, MySql> for chrono::TimeDelta { + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(MySqlTime::decode(value)?.into()) + } +} + impl Type for NaiveDate { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Date) @@ -155,7 +187,14 @@ impl<'r> Decode<'r, MySql> for NaiveDate { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { - decode_date(&value.as_bytes()?[1..])?.ok_or_else(|| UnexpectedNullError.into()) + let buf = value.as_bytes()?; + + // Row decoding should have left the length prefix. + if buf.is_empty() { + return Err("empty buffer".into()); + } + + decode_date(&buf[1..])?.ok_or_else(|| UnexpectedNullError.into()) } MySqlValueFormat::Text => { @@ -214,6 +253,10 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; + if buf.is_empty() { + return Err("empty buffer".into()); + } + let len = buf[0]; let date = decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?; diff --git a/sqlx-mysql/src/types/float.rs b/sqlx-mysql/src/types/float.rs index 6df7c6cec3..0b36a5e875 100644 --- a/sqlx-mysql/src/types/float.rs +++ b/sqlx-mysql/src/types/float.rs @@ -8,6 +8,7 @@ use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; fn real_compatible(ty: &MySqlTypeInfo) -> bool { + // NOTE: `DECIMAL` is explicitly excluded because floating-point numbers have different semantics. matches!(ty.r#type, ColumnType::Float | ColumnType::Double) } @@ -53,12 +54,22 @@ impl Decode<'_, MySql> for f32 { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; - if buf.len() == 8 { + match buf.len() { + // These functions panic if `buf` is not exactly the right size. + 4 => LittleEndian::read_f32(buf), // MySQL can return 8-byte DOUBLE values for a FLOAT - // We take and truncate to f32 as that's the same behavior as *in* MySQL - LittleEndian::read_f64(buf) as f32 - } else { - LittleEndian::read_f32(buf) + // We take and truncate to f32 as that's the same behavior as *in* MySQL, + 8 => LittleEndian::read_f64(buf) as f32, + other => { + // Users may try to decode a DECIMAL as floating point; + // inform them why that's a bad idea. + return Err(format!( + "expected a FLOAT as 4 or 8 bytes, got {other} bytes; \ + note that decoding DECIMAL as `f32` is not supported \ + due to differing semantics" + ) + .into()); + } } } @@ -70,7 +81,26 @@ impl Decode<'_, MySql> for f32 { impl Decode<'_, MySql> for f64 { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?), + MySqlValueFormat::Binary => { + let buf = value.as_bytes()?; + + // The `read_*` functions panic if `buf` is not exactly the right size. + match buf.len() { + // Allow implicit widening here + 4 => LittleEndian::read_f32(buf) as f64, + 8 => LittleEndian::read_f64(buf), + other => { + // Users may try to decode a DECIMAL as floating point; + // inform them why that's a bad idea. + return Err(format!( + "expected a DOUBLE as 4 or 8 bytes, got {other} bytes; \ + note that decoding DECIMAL as `f64` is not supported \ + due to differing semantics" + ) + .into()); + } + } + } MySqlValueFormat::Text => value.as_str()?.parse()?, }) } diff --git a/sqlx-mysql/src/types/int.rs b/sqlx-mysql/src/types/int.rs index 18a64f3155..c4896fa933 100644 --- a/sqlx-mysql/src/types/int.rs +++ b/sqlx-mysql/src/types/int.rs @@ -95,6 +95,20 @@ fn int_decode(value: MySqlValueRef<'_>) -> Result { MySqlValueFormat::Text => value.as_str()?.parse()?, MySqlValueFormat::Binary => { let buf = value.as_bytes()?; + + // Check conditions that could cause `read_int()` to panic. + if buf.is_empty() { + return Err("empty buffer".into()); + } + + if buf.len() > 8 { + return Err(format!( + "expected no more than 8 bytes for integer value, got {}", + buf.len() + ) + .into()); + } + LittleEndian::read_int(buf, buf.len()) } }) diff --git a/sqlx-mysql/src/types/mod.rs b/sqlx-mysql/src/types/mod.rs index 9b7ef29fc7..8408a78baf 100644 --- a/sqlx-mysql/src/types/mod.rs +++ b/sqlx-mysql/src/types/mod.rs @@ -20,6 +20,8 @@ //! | `IpAddr` | VARCHAR, TEXT | //! | `Ipv4Addr` | INET4 (MariaDB-only), VARCHAR, TEXT | //! | `Ipv6Addr` | INET6 (MariaDB-only), VARCHAR, TEXT | +//! | [`MySqlTime`] | TIME (encode and decode full range) | +//! | [`Duration`] | TIME (for decoding positive values only) | //! //! ##### Note: `BOOLEAN`/`BOOL` Type //! MySQL and MariaDB treat `BOOLEAN` as an alias of the `TINYINT` type: @@ -38,6 +40,12 @@ //! Thus, you must use the type override syntax in the query to tell the macros you are expecting //! a `bool` column. See the docs for `query!()` and `query_as!()` for details on this syntax. //! +//! ### NOTE: MySQL's `TIME` type is signed +//! MySQL's `TIME` type can be used as either a time-of-day value, or a signed interval. +//! Thus, it may take on negative values. +//! +//! Decoding a [`std::time::Duration`] returns an error if the `TIME` value is negative. +//! //! ### [`chrono`](https://crates.io/crates/chrono) //! //! Requires the `chrono` Cargo feature flag. @@ -48,7 +56,20 @@ //! | `chrono::DateTime` | TIMESTAMP | //! | `chrono::NaiveDateTime` | DATETIME | //! | `chrono::NaiveDate` | DATE | -//! | `chrono::NaiveTime` | TIME | +//! | `chrono::NaiveTime` | TIME (time-of-day only) | +//! | `chrono::TimeDelta` | TIME (decodes full range; see note for encoding) | +//! +//! ### NOTE: MySQL's `TIME` type is dual-purpose +//! MySQL's `TIME` type can be used as either a time-of-day value, or an interval. +//! However, `chrono::NaiveTime` is designed only to represent a time-of-day. +//! +//! Decoding a `TIME` value as `chrono::NaiveTime` will return an error if the value is out of range. +//! +//! The [`MySqlTime`] type supports the full range and it also implements `TryInto`. +//! +//! Decoding a `chrono::TimeDelta` also supports the full range. +//! +//! To encode a `chrono::TimeDelta`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`. //! //! ### [`time`](https://crates.io/crates/time) //! @@ -59,7 +80,20 @@ //! | `time::PrimitiveDateTime` | DATETIME | //! | `time::OffsetDateTime` | TIMESTAMP | //! | `time::Date` | DATE | -//! | `time::Time` | TIME | +//! | `time::Time` | TIME (time-of-day only) | +//! | `time::Duration` | TIME (decodes full range; see note for encoding) | +//! +//! ### NOTE: MySQL's `TIME` type is dual-purpose +//! MySQL's `TIME` type can be used as either a time-of-day value, or an interval. +//! However, `time::Time` is designed only to represent a time-of-day. +//! +//! Decoding a `TIME` value as `time::Time` will return an error if the value is out of range. +//! +//! The [`MySqlTime`] type supports the full range, and it also implements `TryInto`. +//! +//! Decoding a `time::Duration` also supports the full range. +//! +//! To encode a `time::Duration`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`. //! //! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) //! Requires the `bigdecimal` Cargo feature flag. @@ -102,11 +136,14 @@ pub(crate) use sqlx_core::types::*; +pub use mysql_time::{MySqlTime, MySqlTimeError, MySqlTimeSign}; + mod bool; mod bytes; mod float; mod inet; mod int; +mod mysql_time; mod str; mod text; mod uint; diff --git a/sqlx-mysql/src/types/mysql_time.rs b/sqlx-mysql/src/types/mysql_time.rs new file mode 100644 index 0000000000..b2071ec07f --- /dev/null +++ b/sqlx-mysql/src/types/mysql_time.rs @@ -0,0 +1,720 @@ +//! The [`MysqlTime`] type. + +use crate::protocol::text::ColumnType; +use crate::{MySql, MySqlTypeInfo, MySqlValueFormat}; +use bytes::{Buf, BufMut}; +use sqlx_core::database::Database; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::Type; +use std::cmp::Ordering; +use std::fmt::{Debug, Display, Formatter, Write}; +use std::time::Duration; + +// Similar to `PgInterval` +/// Container for a MySQL `TIME` value, which may be an interval or a time-of-day. +/// +/// Allowed range is `-838:59:59.0` to `838:59:59.0`. +/// +/// If this value is used for a time-of-day, the range should be `00:00:00.0` to `23:59:59.999999`. +/// You can use [`Self::is_time_of_day()`] to check this easily. +/// +/// * [MySQL Manual 13.2.3: The TIME Type](https://dev.mysql.com/doc/refman/8.3/en/time.html) +/// * [MariaDB Manual: TIME](https://mariadb.com/kb/en/time/) +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct MySqlTime { + pub(crate) sign: MySqlTimeSign, + pub(crate) magnitude: TimeMagnitude, +} + +// By using a subcontainer for the actual time magnitude, +// we can still use a derived `Ord` implementation and just flip the comparison for negative values. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub(crate) struct TimeMagnitude { + pub(crate) hours: u32, + pub(crate) minutes: u8, + pub(crate) seconds: u8, + pub(crate) microseconds: u32, +} + +const MAGNITUDE_ZERO: TimeMagnitude = TimeMagnitude { + hours: 0, + minutes: 0, + seconds: 0, + microseconds: 0, +}; + +/// Maximum magnitude (positive or negative). +const MAGNITUDE_MAX: TimeMagnitude = TimeMagnitude { + hours: MySqlTime::HOURS_MAX, + minutes: 59, + seconds: 59, + // Surprisingly this is not 999_999 which is why `MySqlTimeError::SubsecondExcess`. + microseconds: 0, +}; + +/// The sign for a [`MySqlTime`] type. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub enum MySqlTimeSign { + // The protocol actually specifies negative as 1 and positive as 0, + // but by specifying variants this way we can derive `Ord` and it works as expected. + /// The interval is negative (invalid for time-of-day values). + Negative, + /// The interval is positive, or represents a time-of-day. + Positive, +} + +/// Errors returned by [`MySqlTime::new()`]. +#[derive(Debug, thiserror::Error)] +pub enum MySqlTimeError { + /// A field of [`MySqlTime`] exceeded its max range. + #[error("`MySqlTime` field `{field}` cannot exceed {max}, got {value}")] + FieldRange { + field: &'static str, + max: u32, + value: u64, + }, + /// Error returned for time magnitudes (positive or negative) between `838:59:59.0` and `839:00:00.0`. + /// + /// Other range errors should be covered by [`Self::FieldRange`] for the `hours` field. + /// + /// For applications which can tolerate rounding, a valid truncated value is provided. + #[error("`MySqlTime` cannot exceed +/-838:59:59.000000; got {sign}838:59:59.{microseconds:06}")] + SubsecondExcess { + /// The sign of the magnitude. + sign: MySqlTimeSign, + /// The number of microseconds over the maximum. + microseconds: u32, + /// The truncated value, + /// either [`MySqlTime::MIN`] if negative or [`MySqlTime::MAX`] if positive. + truncated: MySqlTime, + }, + /// MySQL coerces `-00:00:00` to `00:00:00` but this API considers that an error. + /// + /// For applications which can tolerate coercion, you can convert this error to [`MySqlTime::ZERO`]. + #[error("attempted to construct a `MySqlTime` value of negative zero")] + NegativeZero, +} + +impl MySqlTime { + /// The `MySqlTime` value corresponding to `TIME '0:00:00.0'` (zero). + pub const ZERO: Self = MySqlTime { + sign: MySqlTimeSign::Positive, + magnitude: MAGNITUDE_ZERO, + }; + + /// The `MySqlTime` value corresponding to `TIME '838:59:59.0'` (max value). + pub const MAX: Self = MySqlTime { + sign: MySqlTimeSign::Positive, + magnitude: MAGNITUDE_MAX, + }; + + /// The `MySqlTime` value corresponding to `TIME '-838:59:59.0'` (min value). + pub const MIN: Self = MySqlTime { + sign: MySqlTimeSign::Negative, + // Same magnitude, opposite sign. + magnitude: MAGNITUDE_MAX, + }; + + // The maximums for the other values are self-evident, but not necessarily this one. + pub(crate) const HOURS_MAX: u32 = 838; + + /// Construct a [`MySqlTime`] that is valid for use as a `TIME` value. + /// + /// ### Errors + /// * [`MySqlTimeError::NegativeZero`] if all fields are 0 but `sign` is [`MySqlSign::Negative`]. + /// * [`MySqlTimeError::FieldRange`] if any field is out of range: + /// * `hours > 838` + /// * `minutes > 59` + /// * `seconds > 59` + /// * `microseconds > 999_999` + /// * [`MySqlTimeError::SubsecondExcess`] if the magnitude is less than one second over the maximum. + /// * Durations 839 hours or greater are covered by `FieldRange`. + pub fn new( + sign: MySqlTimeSign, + hours: u32, + minutes: u8, + seconds: u8, + microseconds: u32, + ) -> Result { + macro_rules! check_fields { + ($($name:ident: $max:expr),+ $(,)?) => { + $( + if $name > $max { + return Err(MySqlTimeError::FieldRange { + field: stringify!($name), + max: $max as u32, + value: $name as u64 + }) + } + )+ + } + } + + check_fields!( + hours: Self::HOURS_MAX, + minutes: 59, + seconds: 59, + microseconds: 999_999 + ); + + let values = TimeMagnitude { + hours, + minutes, + seconds, + microseconds, + }; + + if sign.is_negative() && values == MAGNITUDE_ZERO { + return Err(MySqlTimeError::NegativeZero); + } + + // This is only `true` if less than 1 second over the maximum magnitude + if values > MAGNITUDE_MAX { + return Err(MySqlTimeError::SubsecondExcess { + sign, + microseconds, + truncated: if sign.is_positive() { + Self::MAX + } else { + Self::MIN + }, + }); + } + + Ok(Self { + sign, + magnitude: values, + }) + } + + /// Update the `sign` of this value. + pub fn with_sign(self, sign: MySqlTimeSign) -> Self { + Self { sign, ..self } + } + + /// Return the sign (positive or negative) for this TIME value. + pub fn sign(&self) -> MySqlTimeSign { + self.sign + } + + /// Returns `true` if `self` is zero (equal to [`Self::ZERO`]). + pub fn is_zero(&self) -> bool { + self == &Self::ZERO + } + + /// Returns `true` if `self` is positive or zero, `false` if negative. + pub fn is_positive(&self) -> bool { + self.sign.is_positive() + } + + /// Returns `true` if `self` is negative, `false` if positive or zero. + pub fn is_negative(&self) -> bool { + self.sign.is_positive() + } + + /// Returns `true` if this interval is a valid time-of-day. + /// + /// If `true`, the sign is positive and `hours` is not greater than 23. + pub fn is_valid_time_of_day(&self) -> bool { + self.sign.is_positive() && self.hours() < 24 + } + + /// Get the total number of hours in this interval, from 0 to 838. + /// + /// If this value represents a time-of-day, the range is 0 to 23. + pub fn hours(&self) -> u32 { + self.magnitude.hours + } + + /// Get the number of minutes in this interval, from 0 to 59. + pub fn minutes(&self) -> u8 { + self.magnitude.minutes + } + + /// Get the number of seconds in this interval, from 0 to 59. + pub fn seconds(&self) -> u8 { + self.magnitude.seconds + } + + /// Get the number of seconds in this interval, from 0 to 999,999. + pub fn microseconds(&self) -> u32 { + self.magnitude.microseconds + } + + /// Convert this TIME value to a [`std::time::Duration`]. + /// + /// Returns `None` if this value is negative (cannot be represented). + pub fn to_duration(&self) -> Option { + self.is_positive() + .then(|| Duration::new(self.whole_seconds() as u64, self.subsec_nanos())) + } + + /// Get the whole number of seconds (`seconds + (minutes * 60) + (hours * 3600)`) in this time. + /// + /// Sign is ignored. + pub(crate) fn whole_seconds(&self) -> u32 { + // If `hours` does not exceed 838 then this cannot overflow. + self.hours() * 3600 + self.minutes() as u32 * 60 + self.seconds() as u32 + } + + #[cfg_attr(not(any(feature = "time", feature = "chrono")), allow(dead_code))] + pub(crate) fn whole_seconds_signed(&self) -> i64 { + self.whole_seconds() as i64 * self.sign.signum() as i64 + } + + pub(crate) fn subsec_nanos(&self) -> u32 { + self.microseconds() * 1000 + } + + fn encoded_len(&self) -> u8 { + if self.is_zero() { + 0 + } else if self.microseconds() == 0 { + 8 + } else { + 12 + } + } +} + +impl PartialOrd for MySqlTime { + fn partial_cmp(&self, other: &MySqlTime) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MySqlTime { + fn cmp(&self, other: &Self) -> Ordering { + // If the sides have different signs, we just need to compare those. + if self.sign != other.sign { + return self.sign.cmp(&other.sign); + } + + // We've checked that both sides have the same sign + match self.sign { + MySqlTimeSign::Positive => self.magnitude.cmp(&other.magnitude), + // Reverse the comparison for negative values (smaller negative magnitude = greater) + MySqlTimeSign::Negative => other.magnitude.cmp(&self.magnitude), + } + } +} + +impl Display for MySqlTime { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let TimeMagnitude { + hours, + minutes, + seconds, + microseconds, + } = self.magnitude; + + // Obeys the `+` flag. + Display::fmt(&self.sign(), f)?; + + write!(f, "{hours}:{minutes:02}:{seconds:02}")?; + + // Write microseconds if not zero or a nonzero precision was explicitly requested. + if f.precision().map_or(microseconds != 0, |it| it != 0) { + f.write_char('.')?; + + let mut remaining_precision = f.precision(); + let mut remainder = microseconds; + let mut power_of_10 = 10u32.pow(5); + + // Write digits from most-significant to least, up to the requested precision. + while remainder > 0 && remaining_precision != Some(0) { + let digit = remainder / power_of_10; + // 1 % 1 = 0 + remainder %= power_of_10; + power_of_10 /= 10; + + write!(f, "{digit}")?; + + if let Some(remaining_precision) = &mut remaining_precision { + *remaining_precision = remaining_precision.saturating_sub(1); + } + } + + // If any requested precision remains, pad with zeroes. + if let Some(precision) = remaining_precision.filter(|it| *it != 0) { + write!(f, "{:0precision$}", 0)?; + } + } + + Ok(()) + } +} + +impl Type for MySqlTime { + fn type_info() -> MySqlTypeInfo { + MySqlTypeInfo::binary(ColumnType::Time) + } +} + +impl<'r> Decode<'r, MySql> for MySqlTime { + fn decode(value: ::ValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => { + let mut buf = value.as_bytes()?; + + // Row decoding should have left the length byte on the front. + if buf.is_empty() { + return Err("empty buffer".into()); + } + + let length = buf.get_u8(); + + // MySQL specifies that if all fields are 0 then the length is 0 and no further data is sent + // https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + if length == 0 { + return Ok(Self::ZERO); + } + + if !matches!(buf.len(), 8 | 12) { + return Err(format!( + "expected 8 or 12 bytes for TIME value, got {}", + buf.len() + ) + .into()); + } + + let sign = MySqlTimeSign::from_byte(buf.get_u8())?; + // The wire protocol includes days but the text format doesn't. Isn't that crazy? + let days = buf.get_u32_le(); + let hours = buf.get_u8(); + let minutes = buf.get_u8(); + let seconds = buf.get_u8(); + + let microseconds = if !buf.is_empty() { buf.get_u32_le() } else { 0 }; + + let whole_hours = days + .checked_mul(24) + .and_then(|days_to_hours| days_to_hours.checked_add(hours as u32)) + .ok_or("overflow calculating whole hours from `days * 24 + hours`")?; + + Ok(Self::new( + sign, + whole_hours, + minutes, + seconds, + microseconds, + )?) + } + MySqlValueFormat::Text => parse(value.as_str()?), + } + } +} + +impl<'q> Encode<'q, MySql> for MySqlTime { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer<'q>) -> IsNull { + if self.is_zero() { + buf.put_u8(0); + return IsNull::No; + } + + buf.put_u8(self.encoded_len()); + buf.put_u8(self.sign.to_byte()); + + let TimeMagnitude { + hours: whole_hours, + minutes, + seconds, + microseconds, + } = self.magnitude; + + let days = whole_hours / 24; + let hours = (whole_hours % 24) as u8; + + buf.put_u32_le(days); + buf.put_u8(hours); + buf.put_u8(minutes); + buf.put_u8(seconds); + + if microseconds != 0 { + buf.put_u32_le(microseconds); + } + + IsNull::No + } + + fn size_hint(&self) -> usize { + self.encoded_len() as usize + 1 + } +} + +/// Convert [`MySqlTime`] from [`std::time::Duration`]. +/// +/// ### Note: Precision Truncation +/// [`Duration`] supports nanosecond precision, but MySQL `TIME` values only support microsecond +/// precision. +/// +/// For simplicity, higher precision values are truncated when converting. +/// If you prefer another rounding mode instead, you should apply that to the `Duration` first. +/// +/// See also: [MySQL Manual, section 13.2.6: Fractional Seconds in Time Values](https://dev.mysql.com/doc/refman/8.3/en/fractional-seconds.html) +/// +/// ### Errors: +/// Returns [`MySqlTimeError::FieldRange`] if the given duration is longer than `838:59:59.999999`. +/// +impl TryFrom for MySqlTime { + type Error = MySqlTimeError; + + fn try_from(value: Duration) -> Result { + let hours = value.as_secs() / 3600; + let rem_seconds = value.as_secs() % 3600; + let minutes = (rem_seconds / 60) as u8; + let seconds = (rem_seconds % 60) as u8; + + // Simply divides by 1000 + let microseconds = value.subsec_micros(); + + Self::new( + MySqlTimeSign::Positive, + hours.try_into().map_err(|_| MySqlTimeError::FieldRange { + field: "hours", + max: Self::HOURS_MAX, + value: hours, + })?, + minutes, + seconds, + microseconds, + ) + } +} + +impl MySqlTimeSign { + fn from_byte(b: u8) -> Result { + match b { + 0 => Ok(Self::Positive), + 1 => Ok(Self::Negative), + other => Err(format!("expected 0 or 1 for TIME sign byte, got {other}").into()), + } + } + + fn to_byte(&self) -> u8 { + match self { + // We can't use `#[repr(u8)]` because this is opposite of the ordering we want from `Ord` + Self::Negative => 1, + Self::Positive => 0, + } + } + + fn signum(&self) -> i32 { + match self { + Self::Negative => -1, + Self::Positive => 1, + } + } + + /// Returns `true` if positive, `false` if negative. + pub fn is_positive(&self) -> bool { + matches!(self, Self::Positive) + } + + /// Returns `true` if negative, `false` if positive. + pub fn is_negative(&self) -> bool { + matches!(self, Self::Negative) + } +} + +impl Display for MySqlTimeSign { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Positive if f.sign_plus() => f.write_char('+'), + Self::Negative => f.write_char('-'), + _ => Ok(()), + } + } +} + +impl Type for Duration { + fn type_info() -> MySqlTypeInfo { + MySqlTime::type_info() + } +} + +impl<'r> Decode<'r, MySql> for Duration { + fn decode(value: ::ValueRef<'r>) -> Result { + let time = MySqlTime::decode(value)?; + + time.to_duration().ok_or_else(|| { + format!("`std::time::Duration` can only decode positive TIME values; got {time}").into() + }) + } +} + +// Not exposing this as a `FromStr` impl currently because `MySqlTime` is not designed to be +// a general interchange type. +fn parse(text: &str) -> Result { + let mut segments = text.split(':'); + + let hours = segments + .next() + .ok_or("expected hours segment, got nothing")?; + + let minutes = segments + .next() + .ok_or("expected minutes segment, got nothing")?; + + let seconds = segments + .next() + .ok_or("expected seconds segment, got nothing")?; + + // Include the sign in parsing for convenience; + // the allowed range of whole hours is much smaller than `i32`'s positive range. + let hours: i32 = hours + .parse() + .map_err(|e| format!("error parsing hours from {text:?} (segment {hours:?}): {e}"))?; + + let sign = if hours.is_negative() { + MySqlTimeSign::Negative + } else { + MySqlTimeSign::Positive + }; + + let hours = hours.abs() as u32; + + let minutes: u8 = minutes + .parse() + .map_err(|e| format!("error parsing minutes from {text:?} (segment {minutes:?}): {e}"))?; + + let (seconds, microseconds): (u8, u32) = if let Some((seconds, microseconds)) = + seconds.split_once('.') + { + ( + seconds.parse().map_err(|e| { + format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}") + })?, + parse_microseconds(microseconds).map_err(|e| { + format!("error parsing microseconds from {text:?} (segment {microseconds:?}): {e}") + })?, + ) + } else { + ( + seconds.parse().map_err(|e| { + format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}") + })?, + 0, + ) + }; + + Ok(MySqlTime::new(sign, hours, minutes, seconds, microseconds)?) +} + +/// Parse microseconds from a fractional seconds string. +fn parse_microseconds(micros: &str) -> Result { + const EXPECTED_DIGITS: usize = 6; + + match micros.len() { + 0 => Err("empty string".into()), + len @ ..= EXPECTED_DIGITS => { + // Fewer than 6 digits, multiply to the correct magnitude + let micros: u32 = micros.parse()?; + Ok(micros * 10u32.pow((EXPECTED_DIGITS - len) as u32)) + } + // More digits than expected, truncate + _ => { + Ok(micros[..EXPECTED_DIGITS].parse()?) + } + } + +} + +#[cfg(test)] +mod tests { + use super::MySqlTime; + use crate::types::MySqlTimeSign; + + use super::parse_microseconds; + + #[test] + fn test_display() { + assert_eq!(MySqlTime::ZERO.to_string(), "0:00:00"); + + assert_eq!(format!("{:.0}", MySqlTime::ZERO), "0:00:00"); + + assert_eq!(format!("{:.3}", MySqlTime::ZERO), "0:00:00.000"); + + assert_eq!(format!("{:.6}", MySqlTime::ZERO), "0:00:00.000000"); + + assert_eq!(format!("{:.9}", MySqlTime::ZERO), "0:00:00.000000000"); + + assert_eq!(format!("{:.0}", MySqlTime::MAX), "838:59:59"); + + assert_eq!(format!("{:.3}", MySqlTime::MAX), "838:59:59.000"); + + assert_eq!(format!("{:.6}", MySqlTime::MAX), "838:59:59.000000"); + + assert_eq!(format!("{:.9}", MySqlTime::MAX), "838:59:59.000000000"); + + assert_eq!(format!("{:+.0}", MySqlTime::MAX), "+838:59:59"); + + assert_eq!(format!("{:+.3}", MySqlTime::MAX), "+838:59:59.000"); + + assert_eq!(format!("{:+.6}", MySqlTime::MAX), "+838:59:59.000000"); + + assert_eq!(format!("{:+.9}", MySqlTime::MAX), "+838:59:59.000000000"); + + assert_eq!(format!("{:.0}", MySqlTime::MIN), "-838:59:59"); + + assert_eq!(format!("{:.3}", MySqlTime::MIN), "-838:59:59.000"); + + assert_eq!(format!("{:.6}", MySqlTime::MIN), "-838:59:59.000000"); + + assert_eq!(format!("{:.9}", MySqlTime::MIN), "-838:59:59.000000000"); + + let positive = MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890011).unwrap(); + + assert_eq!(positive.to_string(), "123:45:56.890011"); + assert_eq!(format!("{positive:.0}"), "123:45:56"); + assert_eq!(format!("{positive:.3}"), "123:45:56.890"); + assert_eq!(format!("{positive:.6}"), "123:45:56.890011"); + assert_eq!(format!("{positive:.9}"), "123:45:56.890011000"); + + assert_eq!(format!("{positive:+.0}"), "+123:45:56"); + assert_eq!(format!("{positive:+.3}"), "+123:45:56.890"); + assert_eq!(format!("{positive:+.6}"), "+123:45:56.890011"); + assert_eq!(format!("{positive:+.9}"), "+123:45:56.890011000"); + + let negative = MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890011).unwrap(); + + assert_eq!(negative.to_string(), "-123:45:56.890011"); + assert_eq!(format!("{negative:.0}"), "-123:45:56"); + assert_eq!(format!("{negative:.3}"), "-123:45:56.890"); + assert_eq!(format!("{negative:.6}"), "-123:45:56.890011"); + assert_eq!(format!("{negative:.9}"), "-123:45:56.890011000"); + } + + #[test] + fn test_parse_microseconds() { + assert_eq!( + parse_microseconds("010").unwrap(), + 10_000 + ); + + assert_eq!( + parse_microseconds("0100000000").unwrap(), + 10_000 + ); + + assert_eq!( + parse_microseconds("890").unwrap(), + 890_000 + ); + + assert_eq!( + parse_microseconds("0890").unwrap(), + 89_000 + ); + + assert_eq!( + // Case in point about not exposing this: + // we always truncate excess precision because it's simpler than rounding + // and MySQL should never return a higher precision. + parse_microseconds("123456789").unwrap(), + 123456, + ); + } +} diff --git a/sqlx-mysql/src/types/time.rs b/sqlx-mysql/src/types/time.rs index 8a3c06eb96..c9cd8d6664 100644 --- a/sqlx-mysql/src/types/time.rs +++ b/sqlx-mysql/src/types/time.rs @@ -1,5 +1,6 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::Buf; +use sqlx_core::database::Database; use time::macros::format_description; use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; @@ -8,7 +9,7 @@ use crate::encode::{Encode, IsNull}; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::protocol::text::ColumnType; use crate::type_info::MySqlTypeInfo; -use crate::types::Type; +use crate::types::{MySqlTime, MySqlTimeSign, Type}; use crate::{MySql, MySqlValueFormat, MySqlValueRef}; impl Type for OffsetDateTime { @@ -52,7 +53,7 @@ impl Encode<'_, MySql> for Time { // Time is not negative buf.push(0); - // "date on 4 bytes little-endian format" (?) + // Number of days in the interval; always 0 for time-of-day values. // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); @@ -76,29 +77,11 @@ impl<'r> Decode<'r, MySql> for Time { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { - let mut buf = value.as_bytes()?; - - // data length, expecting 8 or 12 (fractional seconds) - let len = buf.get_u8(); - - // MySQL specifies that if all of hours, minutes, seconds, microseconds - // are 0 then the length is 0 and no further data is send - // https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - if len == 0 { - return Ok(Time::MIDNIGHT); - } - - // is negative : int<1> - let is_negative = buf.get_u8(); - assert_eq!(is_negative, 0, "Negative dates/times are not supported"); - - // "date on 4 bytes little-endian format" (?) - // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding - buf.advance(4); - - decode_time(len - 5, buf) + // Should never panic. + MySqlTime::decode(value)?.try_into() } + // Retaining this parsing for now as it allows us to cross-check our impl. MySqlValueFormat::Text => Time::parse( value.as_str()?, &format_description!("[hour]:[minute]:[second].[subsecond]"), @@ -108,6 +91,53 @@ impl<'r> Decode<'r, MySql> for Time { } } +impl TryFrom for Time { + type Error = BoxDynError; + + fn try_from(time: MySqlTime) -> Result { + if !time.is_valid_time_of_day() { + return Err(format!("MySqlTime value out of range for `time::Time`: {time}").into()); + } + + Ok(Time::from_hms_micro( + // `is_valid_time_of_day()` ensures this won't overflow + time.hours() as u8, + time.minutes(), + time.seconds(), + time.microseconds(), + )?) + } +} + +impl From for time::Duration { + fn from(time: MySqlTime) -> Self { + time::Duration::new(time.whole_seconds_signed(), time.subsec_nanos() as i32) + } +} + +impl TryFrom for MySqlTime { + type Error = BoxDynError; + + fn try_from(value: time::Duration) -> Result { + let sign = if value.is_negative() { MySqlTimeSign::Negative } else { MySqlTimeSign::Positive }; + + // Similar to `TryFrom`, use `std::time::Duration` as an intermediate. + Ok(MySqlTime::try_from(std::time::Duration::try_from(value.abs())?)?.with_sign(sign)) + } +} + +impl Type for time::Duration { + fn type_info() -> MySqlTypeInfo { + MySqlTime::type_info() + } +} + +impl<'r> Decode<'r, MySql> for time::Duration { + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(MySqlTime::decode(value)?.into()) + } +} + impl Type for Date { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Date) @@ -132,7 +162,14 @@ impl<'r> Decode<'r, MySql> for Date { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { - Ok(decode_date(&value.as_bytes()?[1..])?.ok_or(UnexpectedNullError)?) + let buf = value.as_bytes()?; + + // Row decoding should leave the length byte on the front. + if buf.is_empty() { + return Err("empty buffer".into()); + } + + Ok(decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?) } MySqlValueFormat::Text => { let s = value.as_str()?; @@ -183,12 +220,18 @@ impl<'r> Decode<'r, MySql> for PrimitiveDateTime { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { - let buf = value.as_bytes()?; - let len = buf[0]; - let date = decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?; + let mut buf = value.as_bytes()?; + + if buf.is_empty() { + return Err("empty buffer".into()); + } + + let len = buf.get_u8(); + + let date = decode_date(buf)?.ok_or(UnexpectedNullError)?; let dt = if len > 4 { - date.with_time(decode_time(len - 4, &buf[5..])?) + date.with_time(decode_time(&buf[4..])?) } else { date.midnight() }; @@ -255,12 +298,12 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec) { } } -fn decode_time(len: u8, mut buf: &[u8]) -> Result { +fn decode_time(mut buf: &[u8]) -> Result { let hour = buf.get_u8(); let minute = buf.get_u8(); let seconds = buf.get_u8(); - let micros = if len > 3 { + let micros = if !buf.is_empty() { // microseconds : int buf.get_uint_le(buf.len()) } else { diff --git a/sqlx-mysql/src/types/uint.rs b/sqlx-mysql/src/types/uint.rs index 4731f0e433..ef383797c5 100644 --- a/sqlx-mysql/src/types/uint.rs +++ b/sqlx-mysql/src/types/uint.rs @@ -119,6 +119,20 @@ fn uint_decode(value: MySqlValueRef<'_>) -> Result { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; + + // Check conditions that could cause `read_uint()` to panic. + if buf.is_empty() { + return Err("empty buffer".into()); + } + + if buf.len() > 8 { + return Err(format!( + "expected no more than 8 bytes for unsigned integer value, got {}", + buf.len() + ) + .into()); + } + LittleEndian::read_uint(buf, buf.len()) } }) diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index fad95b36a4..e837a53f75 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -9,6 +9,9 @@ use sqlx::{Executor, Row}; use sqlx::types::Text; +use sqlx::mysql::types::MySqlTime; +use sqlx_mysql::types::MySqlTimeSign; + use sqlx_test::{new, test_type}; test_type!(bool(MySql, "false" == false, "true" == true)); @@ -70,34 +73,44 @@ test_type!(uuid_simple(MySql, == sqlx::types::Uuid::parse_str("00000000000000000000000000000000").unwrap().simple() )); +test_type!(mysql_time(MySql, + "TIME '00:00:00.000000'" == MySqlTime::ZERO, + "TIME '-00:00:00.000000'" == MySqlTime::ZERO, + "TIME '838:59:59.0'" == MySqlTime::MAX, + "TIME '-838:59:59.0'" == MySqlTime::MIN, + "TIME '123:45:56.890'" == MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890_000).unwrap(), + "TIME '-123:45:56.890'" == MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890_000).unwrap(), + "TIME '123:45:56.890011'" == MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890_011).unwrap(), + "TIME '-123:45:56.890011'" == MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890_011).unwrap(), +)); + #[cfg(feature = "chrono")] mod chrono { - use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use super::*; test_type!(chrono_date(MySql, - "DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5), - "DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23) + "DATE '2001-01-05'" == NaiveDate::from_ymd_opt(2001, 1, 5).unwrap(), + "DATE '2050-11-23'" == NaiveDate::from_ymd_opt(2050, 11, 23).unwrap() )); test_type!(chrono_time_zero(MySql, - "TIME '00:00:00.000000'" == NaiveTime::from_hms_micro(0, 0, 0, 0) + "TIME '00:00:00.000000'" == NaiveTime::from_hms_micro_opt(0, 0, 0, 0).unwrap() )); test_type!(chrono_time(MySql, - "TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100) + "TIME '05:10:20.115100'" == NaiveTime::from_hms_micro_opt(5, 10, 20, 115100).unwrap() )); test_type!(chrono_date_time(MySql, - "TIMESTAMP '2019-01-02 05:10:20'" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20) + "TIMESTAMP '2019-01-02 05:10:20'" == NaiveDate::from_ymd_opt(2019, 1, 2).unwrap().and_hms_opt(5, 10, 20).unwrap() )); test_type!(chrono_timestamp>(MySql, "TIMESTAMP '2019-01-02 05:10:20.115100'" - == DateTime::::from_utc( - NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), - Utc, + == Utc.from_utc_datetime( + &NaiveDate::from_ymd_opt(2019, 1, 2).unwrap().and_hms_micro_opt(5, 10, 20, 115100).unwrap(), ) ));