diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index 7929ffb211..34da35bcdd 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -2,10 +2,11 @@ use crate::connection::intmap::IntMap; use crate::connection::{execute, ConnectionState}; use crate::error::Error; use crate::from_row::FromRow; +use crate::logger::{BranchParent, BranchResult, DebugDiff}; use crate::type_info::DataType; use crate::SqliteTypeInfo; use sqlx_core::HashMap; -use std::collections::HashSet; +use std::fmt::Debug; use std::str::from_utf8; // affinity @@ -132,7 +133,7 @@ const OP_HALT_IF_NULL: &str = "HaltIfNull"; const MAX_LOOP_COUNT: u8 = 2; const MAX_TOTAL_INSTRUCTION_COUNT: u32 = 100_000; -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash)] enum ColumnType { Single { datatype: DataType, @@ -171,6 +172,32 @@ impl ColumnType { } } +impl core::fmt::Debug for ColumnType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Single { datatype, nullable } => { + let nullable_str = match nullable { + Some(true) => "NULL", + Some(false) => "NOT NULL", + None => "NULL?", + }; + write!(f, "{:?} {}", datatype, nullable_str) + } + Self::Record(columns) => { + f.write_str("Record(")?; + let mut column_iter = columns.iter(); + if let Some(item) = column_iter.next() { + write!(f, "{:?}", item)?; + while let Some(item) = column_iter.next() { + write!(f, ", {:?}", item)?; + } + } + f.write_str(")") + } + } + } +} + #[derive(Debug, Clone, Eq, PartialEq, Hash)] enum RegDataType { Single(ColumnType), @@ -326,7 +353,7 @@ fn opcode_to_type(op: &str) -> DataType { OP_REAL => DataType::Float, OP_BLOB => DataType::Blob, OP_AND | OP_OR => DataType::Bool, - OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Integer, + OP_NEWROWID | OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Integer, OP_STRING8 => DataType::Text, OP_COLUMN | _ => DataType::Null, } @@ -376,18 +403,78 @@ fn root_block_columns( return Ok(row_info); } -#[derive(Debug, Clone, PartialEq)] +struct Sequence(i64); + +impl Sequence { + pub fn new() -> Self { + Self(0) + } + pub fn next(&mut self) -> i64 { + let curr = self.0; + self.0 += 1; + curr + } +} + +#[derive(Debug)] struct QueryState { // The number of times each instruction has been visited pub visited: Vec, - // A log of the order of execution of each instruction - pub history: Vec, + // A unique identifier of the query branch + pub branch_id: i64, + // How many instructions have been executed on this branch (NOT the same as program_i, which is the currently executing instruction of the program) + pub instruction_counter: i64, + // Parent branch this branch was forked from (if any) + pub branch_parent: Option, // State of the virtual machine pub mem: MemoryState, // Results published by the execution pub result: Option, Option)>>, } +impl From<&QueryState> for MemoryState { + fn from(val: &QueryState) -> Self { + val.mem.clone() + } +} + +impl From for MemoryState { + fn from(val: QueryState) -> Self { + val.mem + } +} + +impl From<&QueryState> for BranchParent { + fn from(val: &QueryState) -> Self { + Self { + id: val.branch_id, + idx: val.instruction_counter, + } + } +} + +impl QueryState { + fn get_reference(&self) -> BranchParent { + BranchParent { + id: self.branch_id, + idx: self.instruction_counter, + } + } + fn new_branch(&self, branch_seq: &mut Sequence) -> Self { + Self { + visited: self.visited.clone(), + branch_id: branch_seq.next(), + instruction_counter: 0, + branch_parent: Some(BranchParent { + id: self.branch_id, + idx: self.instruction_counter - 1, //instruction counter is incremented at the start of processing an instruction, so need to subtract 1 to get the 'current' instruction + }), + mem: self.mem.clone(), + result: self.result.clone(), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct MemoryState { // Next instruction to execute @@ -400,22 +487,65 @@ struct MemoryState { pub t: IntMap, } +impl DebugDiff for MemoryState { + fn diff(&self, prev: &Self) -> String { + let r_diff = self.r.diff(&prev.r); + let p_diff = self.p.diff(&prev.p); + let t_diff = self.t.diff(&prev.t); + + let mut differences = String::new(); + for (i, v) in r_diff { + if !differences.is_empty() { + differences.push('\n'); + } + differences.push_str(&format!("r[{}]={:?}", i, v)) + } + for (i, v) in p_diff { + if !differences.is_empty() { + differences.push('\n'); + } + differences.push_str(&format!("p[{}]={:?}", i, v)) + } + for (i, v) in t_diff { + if !differences.is_empty() { + differences.push('\n'); + } + differences.push_str(&format!("t[{}]={:?}", i, v)) + } + differences + } +} + struct BranchList { states: Vec, - visited_branch_state: HashSet, + visited_branch_state: HashMap, } impl BranchList { pub fn new(state: QueryState) -> Self { Self { states: vec![state], - visited_branch_state: HashSet::new(), + visited_branch_state: HashMap::new(), } } - pub fn push(&mut self, state: QueryState) { - if !self.visited_branch_state.contains(&state.mem) { - self.visited_branch_state.insert(state.mem.clone()); - self.states.push(state); + pub fn push( + &mut self, + mut state: QueryState, + logger: &mut crate::logger::QueryPlanLogger<'_, R, MemoryState, P>, + ) { + logger.add_branch(&state, &state.branch_parent.unwrap()); + match self.visited_branch_state.entry(state.mem) { + std::collections::hash_map::Entry::Vacant(entry) => { + //this state is not identical to another state, so it will need to be processed + state.mem = entry.key().clone(); //replace state.mem since .entry() moved it + entry.insert(state.get_reference()); + self.states.push(state); + } + std::collections::hash_map::Entry::Occupied(entry) => { + //already saw a state identical to this one, so no point in processing it + state.mem = entry.key().clone(); //replace state.mem since .entry() moved it + logger.add_result(state, BranchResult::Dedup(entry.get().clone())); + } } } pub fn pop(&mut self) -> Option { @@ -436,12 +566,13 @@ pub(super) fn explain( .collect::, Error>>()?; let program_size = program.len(); - let mut logger = - crate::logger::QueryPlanLogger::new(query, &program, conn.log_settings.clone()); - + let mut logger = crate::logger::QueryPlanLogger::new(query, &program); + let mut branch_seq = Sequence::new(); let mut states = BranchList::new(QueryState { visited: vec![0; program_size], - history: Vec::new(), + branch_id: branch_seq.next(), + branch_parent: None, + instruction_counter: 0, result: None, mem: MemoryState { program_i: 0, @@ -457,22 +588,20 @@ pub(super) fn explain( while let Some(mut state) = states.pop() { while state.mem.program_i < program_size { let (_, ref opcode, p1, p2, p3, ref p4) = program[state.mem.program_i]; - state.history.push(state.mem.program_i); + + logger.add_operation(state.mem.program_i, &state); + state.instruction_counter += 1; //limit the number of 'instructions' that can be evaluated if gas > 0 { gas -= 1; } else { + logger.add_result(state, BranchResult::GasLimit); break; } if state.visited[state.mem.program_i] > MAX_LOOP_COUNT { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } - + logger.add_result(state, BranchResult::LoopLimit); //avoid (infinite) loops by breaking if we ever hit the same instruction twice break; } @@ -513,23 +642,63 @@ pub(super) fn explain( OP_DECR_JUMP_ZERO | OP_ELSE_EQ | OP_EQ | OP_FILTER | OP_FOUND | OP_GE | OP_GT | OP_IDX_GE | OP_IDX_GT | OP_IDX_LE | OP_IDX_LT | OP_IF_NO_HOPE | OP_IF_NOT | OP_IF_NOT_OPEN | OP_IF_NOT_ZERO | OP_IF_NULL_ROW | OP_IF_SMALLER - | OP_INCR_VACUUM | OP_IS_NULL | OP_IS_NULL_OR_TYPE | OP_LE | OP_LT | OP_NE - | OP_NEXT | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_ONCE | OP_PREV | OP_PROGRAM + | OP_INCR_VACUUM | OP_IS_NULL_OR_TYPE | OP_LE | OP_LT | OP_NE | OP_NEXT + | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_ONCE | OP_PREV | OP_PROGRAM | OP_ROW_SET_READ | OP_ROW_SET_TEST | OP_SEEK_GE | OP_SEEK_GT | OP_SEEK_LE | OP_SEEK_LT | OP_SEEK_ROW_ID | OP_SEEK_SCAN | OP_SEQUENCE_TEST | OP_SORTER_NEXT | OP_V_FILTER | OP_V_NEXT => { // goto or next instruction (depending on actual values) - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; - states.push(branch_state); + states.push(branch_state, &mut logger); state.mem.program_i += 1; continue; } + OP_IS_NULL => { + // goto if p1 is null + + //branch if maybe null + let might_branch = match state.mem.r.get(&p1) { + Some(r_p1) => !matches!(r_p1.map_to_nullable(), Some(false)), + _ => false, + }; + + //nobranch if maybe not null + let might_not_branch = match state.mem.r.get(&p1) { + Some(r_p1) => !matches!(r_p1.map_to_datatype(), DataType::Null), + _ => false, + }; + + if might_branch { + let mut branch_state = state.new_branch(&mut branch_seq); + branch_state.mem.program_i = p2 as usize; + branch_state + .mem + .r + .insert(p1, RegDataType::Single(ColumnType::default())); + + states.push(branch_state, &mut logger); + } + + if might_not_branch { + state.mem.program_i += 1; + if let Some(RegDataType::Single(ColumnType::Single { nullable, .. })) = + state.mem.r.get_mut(&p1) + { + *nullable = Some(false); + } + continue; + } else { + logger.add_result(state, BranchResult::Branched); + break; + } + } + OP_NOT_NULL => { - // goto or next instruction (depending on actual values) + // goto if p1 is not null let might_branch = match state.mem.r.get(&p1) { Some(r_p1) => !matches!(r_p1.map_to_datatype(), DataType::Null), @@ -542,7 +711,7 @@ pub(super) fn explain( }; if might_branch { - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; if let Some(RegDataType::Single(ColumnType::Single { nullable, .. })) = branch_state.mem.r.get_mut(&p1) @@ -550,7 +719,7 @@ pub(super) fn explain( *nullable = Some(false); } - states.push(branch_state); + states.push(branch_state, &mut logger); } if might_not_branch { @@ -561,6 +730,7 @@ pub(super) fn explain( .insert(p1, RegDataType::Single(ColumnType::default())); continue; } else { + logger.add_result(state, BranchResult::Branched); break; } } @@ -571,9 +741,9 @@ pub(super) fn explain( //don't bother checking actual types, just don't branch to instruction 0 if p2 != 0 { - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; - states.push(branch_state); + states.push(branch_state, &mut logger); } state.mem.program_i += 1; @@ -594,13 +764,13 @@ pub(super) fn explain( }; if might_branch { - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; if p3 == 0 { branch_state.mem.r.insert(p1, RegDataType::Int(1)); } - states.push(branch_state); + states.push(branch_state, &mut logger); } if might_not_branch { @@ -610,6 +780,7 @@ pub(super) fn explain( } continue; } else { + logger.add_result(state, BranchResult::Branched); break; } } @@ -631,12 +802,12 @@ pub(super) fn explain( let loop_detected = state.visited[state.mem.program_i] > 1; if might_branch || loop_detected { - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; if let Some(RegDataType::Int(r_p1)) = branch_state.mem.r.get_mut(&p1) { *r_p1 -= 1; } - states.push(branch_state); + states.push(branch_state, &mut logger); } if might_not_branch { @@ -656,6 +827,7 @@ pub(super) fn explain( } continue; } else { + logger.add_result(state, BranchResult::Branched); break; } } @@ -672,7 +844,7 @@ pub(super) fn explain( if matches!(cursor.is_empty(&state.mem.t), None | Some(true)) { //only take this branch if the cursor is empty - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; if let Some(cur) = branch_state.mem.p.get(&p1) { @@ -680,7 +852,7 @@ pub(super) fn explain( tab.is_empty = Some(true); } } - states.push(branch_state); + states.push(branch_state, &mut logger); } if matches!(cursor.is_empty(&state.mem.t), None | Some(false)) { @@ -688,16 +860,12 @@ pub(super) fn explain( state.mem.program_i += 1; continue; } else { + logger.add_result(state, BranchResult::Branched); break; } } - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } - + logger.add_result(state, BranchResult::Branched); break; } @@ -726,34 +894,15 @@ pub(super) fn explain( state.mem.r.remove(&p1); continue; } else { - if logger.log_enabled() { - let program_history: Vec<&( - i64, - String, - i64, - i64, - i64, - Vec, - )> = state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } - + logger.add_result(state, BranchResult::Error); break; } } else { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } + logger.add_result(state, BranchResult::Error); break; } } else { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } + logger.add_result(state, BranchResult::Error); break; } } @@ -765,12 +914,11 @@ pub(super) fn explain( state.mem.program_i = (*return_i + 1) as usize; state.mem.r.remove(&p1); continue; + } else if p3 == 1 { + state.mem.program_i += 1; + continue; } else { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } + logger.add_result(state, BranchResult::Error); break; } } @@ -796,11 +944,7 @@ pub(super) fn explain( continue; } } else { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } + logger.add_result(state, BranchResult::Error); break; } } @@ -808,17 +952,17 @@ pub(super) fn explain( OP_JUMP => { // goto one of , , or based on the result of a prior compare - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p1 as usize; - states.push(branch_state); + states.push(branch_state, &mut logger); - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p2 as usize; - states.push(branch_state); + states.push(branch_state, &mut logger); - let mut branch_state = state.clone(); + let mut branch_state = state.new_branch(&mut branch_seq); branch_state.mem.program_i = p3 as usize; - states.push(branch_state); + states.push(branch_state, &mut logger); } OP_COLUMN => { @@ -889,18 +1033,35 @@ pub(super) fn explain( } OP_INSERT | OP_IDX_INSERT | OP_SORTER_INSERT => { - if let Some(RegDataType::Single(ColumnType::Record(record))) = - state.mem.r.get(&p2) - { - if let Some(TableDataType { cols, is_empty }) = state - .mem - .p - .get(&p1) - .and_then(|cur| cur.table_mut(&mut state.mem.t)) - { - // Insert the record into wherever pointer p1 is - *cols = record.clone(); - *is_empty = Some(false); + if let Some(RegDataType::Single(columntype)) = state.mem.r.get(&p2) { + match columntype { + ColumnType::Record(record) => { + if let Some(TableDataType { cols, is_empty }) = state + .mem + .p + .get(&p1) + .and_then(|cur| cur.table_mut(&mut state.mem.t)) + { + // Insert the record into wherever pointer p1 is + *cols = record.clone(); + *is_empty = Some(false); + } + } + ColumnType::Single { + datatype: DataType::Null, + nullable: _, + } => { + if let Some(TableDataType { is_empty, .. }) = state + .mem + .p + .get(&p1) + .and_then(|cur| cur.table_mut(&mut state.mem.t)) + { + // Insert a null record into wherever pointer p1 is + *is_empty = Some(false); + } + } + _ => {} } } //Noop if the register p2 isn't a record, or if pointer p1 does not exist @@ -1035,7 +1196,7 @@ pub(super) fn explain( ); } - _ => logger.add_unknown_operation(&program[state.mem.program_i]), + _ => logger.add_unknown_operation(state.mem.program_i), } } @@ -1284,44 +1445,39 @@ pub(super) fn explain( OP_RESULT_ROW => { // output = r[p1 .. p1 + p2] - - state.result = Some( - (p1..p1 + p2) - .map(|i| { - let coltype = state.mem.r.get(&i); - - let sqltype = - coltype.map(|d| d.map_to_datatype()).map(SqliteTypeInfo); - let nullable = - coltype.map(|d| d.map_to_nullable()).unwrap_or_default(); - - (sqltype, nullable) - }) - .collect(), + let result: Vec<_> = (p1..p1 + p2) + .map(|i| { + state + .mem + .r + .get(&i) + .map(RegDataType::map_to_columntype) + .unwrap_or_default() + }) + .collect(); + + let mut branch_state = state.new_branch(&mut branch_seq); + branch_state.mem.program_i += 1; + states.push(branch_state, &mut logger); + + logger.add_result( + state, + BranchResult::Result(IntMap::from_dense_record(&result)), ); - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, Some(state.result.clone()))); - } - - result_states.push(state.clone()); + result_states.push(result); + break; } OP_HALT => { - if logger.log_enabled() { - let program_history: Vec<&(i64, String, i64, i64, i64, Vec)> = - state.history.iter().map(|i| &program[*i]).collect(); - logger.add_result((program_history, None)); - } + logger.add_result(state, BranchResult::Halt); break; } _ => { // ignore unsupported operations // if we fail to find an r later, we just give up - logger.add_unknown_operation(&program[state.mem.program_i]); + logger.add_unknown_operation(state.mem.program_i); } } @@ -1332,31 +1488,32 @@ pub(super) fn explain( let mut output: Vec> = Vec::new(); let mut nullable: Vec> = Vec::new(); - while let Some(state) = result_states.pop() { + while let Some(result) = result_states.pop() { // find the datatype info from each ResultRow execution - if let Some(result) = state.result { - let mut idx = 0; - for (this_type, this_nullable) in result { - if output.len() == idx { - output.push(this_type); - } else if output[idx].is_none() - || matches!(output[idx], Some(SqliteTypeInfo(DataType::Null))) - { - output[idx] = this_type; - } + let mut idx = 0; + for this_col in result { + let this_type = this_col.map_to_datatype(); + let this_nullable = this_col.map_to_nullable(); + if output.len() == idx { + output.push(Some(SqliteTypeInfo(this_type))); + } else if output[idx].is_none() + || matches!(output[idx], Some(SqliteTypeInfo(DataType::Null))) + && !matches!(this_type, DataType::Null) + { + output[idx] = Some(SqliteTypeInfo(this_type)); + } - if nullable.len() == idx { - nullable.push(this_nullable); - } else if let Some(ref mut null) = nullable[idx] { - //if any ResultRow's column is nullable, the final result is nullable - if let Some(this_null) = this_nullable { - *null |= this_null; - } - } else { - nullable[idx] = this_nullable; + if nullable.len() == idx { + nullable.push(this_nullable); + } else if let Some(ref mut null) = nullable[idx] { + //if any ResultRow's column is nullable, the final result is nullable + if let Some(this_null) = this_nullable { + *null |= this_null; } - idx += 1; + } else { + nullable[idx] = this_nullable; } + idx += 1; } } diff --git a/sqlx-sqlite/src/connection/intmap.rs b/sqlx-sqlite/src/connection/intmap.rs index 3bf4f886d2..0c0e5ce52a 100644 --- a/sqlx-sqlite/src/connection/intmap.rs +++ b/sqlx-sqlite/src/connection/intmap.rs @@ -1,10 +1,16 @@ +use std::{fmt::Debug, hash::Hash}; + /// Simplistic map implementation built on a Vec of Options (index = key) -#[derive(Debug, Clone, Eq, Default)] -pub(crate) struct IntMap( - Vec>, -); +#[derive(Debug, Clone, Eq)] +pub(crate) struct IntMap(Vec>); + +impl Default for IntMap { + fn default() -> Self { + IntMap(Vec::new()) + } +} -impl IntMap { +impl IntMap { pub(crate) fn new() -> Self { Self(Vec::new()) } @@ -17,10 +23,6 @@ impl IntMap { idx } - pub(crate) fn from_dense_record(record: &Vec) -> Self { - Self(record.iter().cloned().map(Some).collect()) - } - pub(crate) fn values_mut(&mut self) -> impl Iterator { self.0.iter_mut().filter_map(Option::as_mut) } @@ -67,9 +69,67 @@ impl IntMap { None => None, } } + + pub(crate) fn iter(&self) -> impl Iterator> { + self.0.iter().map(Option::as_ref) + } + + pub(crate) fn iter_entries(&self) -> impl Iterator { + self.0 + .iter() + .enumerate() + .filter_map(|(i, v)| v.as_ref().map(|v: &V| (i as i64, v))) + } + + pub(crate) fn last_index(&self) -> Option { + self.0.iter().rposition(|v| v.is_some()).map(|i| i as i64) + } +} + +impl IntMap { + pub(crate) fn get_mut_or_default<'a>(&'a mut self, idx: &i64) -> &'a mut V { + let idx: usize = self.expand(*idx); + + let item: &mut Option = &mut self.0[idx]; + if item.is_none() { + *item = Some(V::default()); + } + + return self.0[idx].as_mut().unwrap(); + } +} + +impl IntMap { + pub(crate) fn from_dense_record(record: &Vec) -> Self { + Self(record.iter().cloned().map(Some).collect()) + } +} + +impl IntMap { + /// get the additions to this intmap compared to the prev intmap + pub(crate) fn diff<'a, 'b, 'c>( + &'a self, + prev: &'b Self, + ) -> impl Iterator)> + where + 'a: 'c, + 'b: 'c, + { + let self_pad = if prev.0.len() > self.0.len() { + prev.0.len() - self.0.len() + } else { + 0 + }; + self.iter() + .chain(std::iter::repeat(None).take(self_pad)) + .zip(prev.iter().chain(std::iter::repeat(None))) + .enumerate() + .filter(|(_i, (n, p))| n != p) + .map(|(i, (n, _p))| (i, n)) + } } -impl std::hash::Hash for IntMap { +impl Hash for IntMap { fn hash(&self, state: &mut H) { for value in self.values() { value.hash(state); @@ -77,7 +137,7 @@ impl std::hash::H } } -impl PartialEq for IntMap { +impl PartialEq for IntMap { fn eq(&self, other: &Self) -> bool { if !self .0 @@ -98,9 +158,7 @@ impl PartialEq fo } } -impl FromIterator<(i64, V)> - for IntMap -{ +impl FromIterator<(i64, V)> for IntMap { fn from_iter(iter: I) -> Self where I: IntoIterator, diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index fadc9e4175..b6afc0c52b 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -30,7 +30,7 @@ pub(crate) mod execute; mod executor; mod explain; mod handle; -mod intmap; +pub(crate) mod intmap; mod worker; diff --git a/sqlx-sqlite/src/logger.rs b/sqlx-sqlite/src/logger.rs index d605ea13a8..7e76c946a6 100644 --- a/sqlx-sqlite/src/logger.rs +++ b/sqlx-sqlite/src/logger.rs @@ -1,87 +1,432 @@ -use sqlx_core::{connection::LogSettings, logger}; +use crate::connection::intmap::IntMap; use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; pub(crate) use sqlx_core::logger::*; -pub struct QueryPlanLogger<'q, O: Debug + Hash + Eq, R: Debug, P: Debug> { +#[derive(Debug)] +pub(crate) enum BranchResult { + Result(R), + Dedup(BranchParent), + Halt, + Error, + GasLimit, + LoopLimit, + Branched, +} + +#[derive(Debug, Clone, Copy, PartialEq, Hash, Eq, Ord, PartialOrd)] +pub(crate) struct BranchParent { + pub id: i64, + pub idx: i64, +} + +#[derive(Debug)] +pub(crate) struct InstructionHistory { + pub program_i: usize, + pub state: S, +} + +pub(crate) trait DebugDiff { + fn diff(&self, prev: &Self) -> String; +} + +pub struct QueryPlanLogger<'q, R: Debug + 'static, S: Debug + DebugDiff + 'static, P: Debug> { sql: &'q str, - unknown_operations: HashSet, - results: Vec, + unknown_operations: HashSet, + branch_origins: IntMap, + branch_results: IntMap>, + branch_operations: IntMap>>, program: &'q [P], - settings: LogSettings, } -impl<'q, O: Debug + Hash + Eq, R: Debug, P: Debug> QueryPlanLogger<'q, O, R, P> { - pub fn new(sql: &'q str, program: &'q [P], settings: LogSettings) -> Self { +/// convert a string into dot format +fn dot_escape_string(value: impl AsRef) -> String { + value + .as_ref() + .replace("\\", "\\\\") + .replace("\"", "'") + .replace("\n", "\\n") + .to_string() +} + +impl core::fmt::Display for QueryPlanLogger<'_, R, S, P> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + //writes query plan history in dot format + f.write_str("digraph {\n")?; + + f.write_str("subgraph operations {\n")?; + f.write_str("style=\"rounded\";\nnode [shape=\"point\"];\n")?; + + let all_states: std::collections::HashMap> = self + .branch_operations + .iter_entries() + .flat_map( + |(branch_id, instructions): (i64, &IntMap>)| { + instructions.iter_entries().map( + move |(idx, ih): (i64, &InstructionHistory)| { + (BranchParent { id: branch_id, idx }, ih) + }, + ) + }, + ) + .collect(); + + let mut instruction_uses: IntMap> = Default::default(); + for (k, state) in all_states.iter() { + let entry = instruction_uses.get_mut_or_default(&(state.program_i as i64)); + entry.push(k.clone()); + } + + let mut branch_children: std::collections::HashMap> = + Default::default(); + + let mut branched_with_state: std::collections::HashSet = Default::default(); + + for (branch_id, branch_parent) in self.branch_origins.iter_entries() { + let entry = branch_children.entry(*branch_parent).or_default(); + entry.push(BranchParent { + id: branch_id, + idx: 0, + }); + } + + for (idx, instruction) in self.program.iter().enumerate() { + let escaped_instruction = dot_escape_string(format!("{:?}", instruction)); + write!( + f, + "subgraph cluster_{} {{ label=\"{}\"", + idx, escaped_instruction + )?; + + if self.unknown_operations.contains(&idx) { + f.write_str(" style=dashed")?; + } + + f.write_str(";\n")?; + + let mut state_list: std::collections::BTreeMap< + String, + Vec<(BranchParent, Option)>, + > = Default::default(); + + write!(f, "i{}[style=invis];", idx)?; + + if let Some(this_instruction_uses) = instruction_uses.get(&(idx as i64)) { + for curr_ref in this_instruction_uses.iter() { + if let Some(curr_state) = all_states.get(curr_ref) { + let next_ref = BranchParent { + id: curr_ref.id, + idx: curr_ref.idx + 1, + }; + + if let Some(next_state) = all_states.get(&next_ref) { + let state_diff = next_state.state.diff(&curr_state.state); + + state_list + .entry(state_diff) + .or_default() + .push((curr_ref.clone(), Some(next_ref))); + } else { + state_list + .entry(Default::default()) + .or_default() + .push((curr_ref.clone(), None)); + }; + + if let Some(children) = branch_children.get(curr_ref) { + for next_ref in children { + if let Some(next_state) = all_states.get(&next_ref) { + let state_diff = next_state.state.diff(&curr_state.state); + + if !state_diff.is_empty() { + branched_with_state.insert(next_ref.clone()); + } + + state_list + .entry(state_diff) + .or_default() + .push((curr_ref.clone(), Some(next_ref.clone()))); + } + } + }; + } + } + + for curr_ref in this_instruction_uses { + if branch_children.contains_key(curr_ref) { + write!(f, "\"b{}p{}\";", curr_ref.id, curr_ref.idx)?; + } + } + } else { + write!(f, "i{}->i{}[style=invis];", idx - 1, idx)?; + } + + for (state_num, (state_diff, ref_list)) in state_list.iter().enumerate() { + if !state_diff.is_empty() { + let escaped_state = dot_escape_string(state_diff); + write!( + f, + "subgraph \"cluster_i{}s{}\" {{\nlabel=\"{}\"\n", + idx, state_num, escaped_state + )?; + } + + for (curr_ref, next_ref) in ref_list { + if let Some(next_ref) = next_ref { + let next_program_i = all_states + .get(&next_ref) + .map(|s| s.program_i.to_string()) + .unwrap_or_default(); + + if branched_with_state.contains(next_ref) { + write!( + f, + "\"b{}p{}_b{}p{}\"[tooltip=\"next:{}\"];", + curr_ref.id, + curr_ref.idx, + next_ref.id, + next_ref.idx, + next_program_i + )?; + continue; + } else { + write!( + f, + "\"b{}p{}\"[tooltip=\"next:{}\"];", + curr_ref.id, curr_ref.idx, next_program_i + )?; + } + } else { + write!(f, "\"b{}p{}\";", curr_ref.id, curr_ref.idx)?; + } + } + + if !state_diff.is_empty() { + f.write_str("}\n")?; + } + } + + f.write_str("}\n")?; + } + + f.write_str("};\n")?; //subgraph operations + + let max_branch_id: i64 = [ + self.branch_operations.last_index().unwrap_or(0), + self.branch_results.last_index().unwrap_or(0), + self.branch_results.last_index().unwrap_or(0), + ] + .into_iter() + .max() + .unwrap_or(0); + + f.write_str("subgraph branches {\n")?; + for branch_id in 0..=max_branch_id { + write!(f, "subgraph b{}{{", branch_id)?; + + let branch_num = branch_id as usize; + let color_names = [ + "blue", + "red", + "cyan", + "yellow", + "green", + "magenta", + "orange", + "purple", + "orangered", + "sienna", + "olivedrab", + "pink", + ]; + let color_name_root = color_names[branch_num % color_names.len()]; + let color_name_suffix = match (branch_num / color_names.len()) % 4 { + 0 => "1", + 1 => "4", + 2 => "3", + 3 => "2", + _ => "", + }; //colors are easily confused after color_names.len() * 2, and outright reused after color_names.len() * 4 + write!( + f, + "edge [colorscheme=x11 color={}{}];", + color_name_root, color_name_suffix + )?; + + let mut instruction_list: Vec<(BranchParent, &InstructionHistory)> = Vec::new(); + if let Some(parent) = self.branch_origins.get(&branch_id) { + if let Some(parent_state) = all_states.get(parent) { + instruction_list.push((parent.clone(), parent_state)); + } + } + if let Some(instructions) = self.branch_operations.get(&branch_id) { + for instruction in instructions.iter_entries() { + instruction_list.push(( + BranchParent { + id: branch_id, + idx: instruction.0, + }, + instruction.1, + )) + } + } + + let mut instructions_iter = instruction_list.into_iter(); + + if let Some((cur_ref, _)) = instructions_iter.next() { + let mut prev_ref = cur_ref; + + while let Some((cur_ref, _)) = instructions_iter.next() { + if branched_with_state.contains(&cur_ref) { + write!( + f, + "\"b{}p{}\" -> \"b{}p{}_b{}p{}\" -> \"b{}p{}\"\n", + prev_ref.id, + prev_ref.idx, + prev_ref.id, + prev_ref.idx, + cur_ref.id, + cur_ref.idx, + cur_ref.id, + cur_ref.idx + )?; + } else { + write!( + f, + "\"b{}p{}\" -> \"b{}p{}\";", + prev_ref.id, prev_ref.idx, cur_ref.id, cur_ref.idx + )?; + } + prev_ref = cur_ref; + } + + //draw edge to the result of this branch + if let Some(result) = self.branch_results.get(&branch_id) { + if let BranchResult::Dedup(dedup_ref) = result { + write!( + f, + "\"b{}p{}\"->\"b{}p{}\" [style=dotted]", + prev_ref.id, prev_ref.idx, dedup_ref.id, dedup_ref.idx + )?; + } else { + let escaped_result = dot_escape_string(format!("{:?}", result)); + write!( + f, + "\"b{}p{}\" ->\"{}\"; \"{}\" [shape=box];", + prev_ref.id, prev_ref.idx, escaped_result, escaped_result + )?; + } + } else { + write!( + f, + "\"b{}p{}\" ->\"NoResult\"; \"NoResult\" [shape=box];", + prev_ref.id, prev_ref.idx + )?; + } + } + f.write_str("};\n")?; + } + f.write_str("};\n")?; //branches + + f.write_str("}\n")?; + Ok(()) + } +} + +impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> QueryPlanLogger<'q, R, S, P> { + pub fn new(sql: &'q str, program: &'q [P]) -> Self { Self { sql, unknown_operations: HashSet::new(), - results: Vec::new(), + branch_origins: IntMap::new(), + branch_results: IntMap::new(), + branch_operations: IntMap::new(), program, - settings, } } pub fn log_enabled(&self) -> bool { - if let Some((tracing_level, log_level)) = - logger::private_level_filter_to_levels(self.settings.statements_level) - { - log::log_enabled!(log_level) - || sqlx_core::private_tracing_dynamic_enabled!(tracing_level) - } else { - false + log::log_enabled!(target: "sqlx::explain", log::Level::Trace) + || private_tracing_dynamic_enabled!(target: "sqlx::explain", tracing::Level::TRACE) + } + + pub fn add_branch(&mut self, state: I, parent: &BranchParent) + where + BranchParent: From, + { + if !self.log_enabled() { + return; + } + let branch: BranchParent = BranchParent::from(state); + self.branch_origins.insert(branch.id, parent.clone()); + } + + pub fn add_operation(&mut self, program_i: usize, state: I) + where + BranchParent: From, + S: From, + { + if !self.log_enabled() { + return; } + let branch: BranchParent = BranchParent::from(state); + let state: S = S::from(state); + self.branch_operations + .get_mut_or_default(&branch.id) + .insert(branch.idx, InstructionHistory { program_i, state }); } - pub fn add_result(&mut self, result: R) { - self.results.push(result); + pub fn add_result(&mut self, state: I, result: BranchResult) + where + BranchParent: for<'a> From<&'a I>, + S: From, + { + if !self.log_enabled() { + return; + } + let branch: BranchParent = BranchParent::from(&state); + self.branch_results.insert(branch.id, result); } - pub fn add_unknown_operation(&mut self, operation: O) { + pub fn add_unknown_operation(&mut self, operation: usize) { + if !self.log_enabled() { + return; + } self.unknown_operations.insert(operation); } pub fn finish(&self) { - let lvl = self.settings.statements_level; - - if let Some((tracing_level, log_level)) = logger::private_level_filter_to_levels(lvl) { - let log_is_enabled = log::log_enabled!(target: "sqlx::explain", log_level) - || private_tracing_dynamic_enabled!(target: "sqlx::explain", tracing_level); - if log_is_enabled { - let mut summary = parse_query_summary(&self.sql); - - let sql = if summary != self.sql { - summary.push_str(" …"); - format!( - "\n\n{}\n", - sqlformat::format( - &self.sql, - &sqlformat::QueryParams::None, - sqlformat::FormatOptions::default() - ) - ) - } else { - String::new() - }; - - let message = format!( - "{}; program:{:?}, unknown_operations:{:?}, results: {:?}{}", - summary, self.program, self.unknown_operations, self.results, sql - ); - - sqlx_core::private_tracing_dynamic_event!( - target: "sqlx::explain", - tracing_level, - message, - ); - } + if !self.log_enabled() { + return; } + + let mut summary = parse_query_summary(&self.sql); + + let sql = if summary != self.sql { + summary.push_str(" …"); + format!( + "\n\n{}\n", + sqlformat::format( + &self.sql, + &sqlformat::QueryParams::None, + sqlformat::FormatOptions::default() + ) + ) + } else { + String::new() + }; + + sqlx_core::private_tracing_dynamic_event!( + target: "sqlx::explain", + tracing::Level::TRACE, + "{}; program:\n{}\n\n{:?}", summary, self, sql + ); } } -impl<'q, O: Debug + Hash + Eq, R: Debug, P: Debug> Drop for QueryPlanLogger<'q, O, R, P> { +impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> Drop for QueryPlanLogger<'q, R, S, P> { fn drop(&mut self) { self.finish(); } diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 18d43fa815..5458eaaa82 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -278,6 +278,26 @@ async fn it_describes_update_with_returning() -> anyhow::Result<()> { assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); + let d = conn + .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING *") + .await?; + + assert_eq!(d.columns().len(), 3); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + assert_eq!(d.column(1).type_info().name(), "TEXT"); + assert_eq!(d.nullable(1), Some(false)); + assert_eq!(d.column(2).type_info().name(), "BOOLEAN"); + //assert_eq!(d.nullable(2), Some(false)); //query analysis is allowed to notice that it is always set to true by the update + + let d = conn + .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING id") + .await?; + + assert_eq!(d.columns().len(), 1); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + Ok(()) } @@ -592,6 +612,42 @@ async fn it_describes_union() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_describes_having_group_by() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe( + r#" + WITH tweet_reply_unq as ( --tweets with a single response + SELECT tweet_id id + FROM tweet_reply + GROUP BY tweet_id + HAVING COUNT(1) = 1 + ) + SELECT + ( + SELECT COUNT(*) + FROM ( + SELECT NULL + FROM tweet + JOIN tweet_reply_unq + USING (id) + WHERE tweet.owner_id = accounts.id + ) + ) single_reply_count + FROM accounts + WHERE id = ?1 + "#, + ) + .await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + Ok(()) +} + //documents failures originally found through property testing #[sqlx_macros::test] async fn it_describes_strange_queries() -> anyhow::Result<()> {