Skip to content

Commit

Permalink
add compiler parameter for database module
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Jan 24, 2024
1 parent 299965f commit 1018393
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 40 deletions.
22 changes: 22 additions & 0 deletions prqlc/prqlc-ast/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,29 @@ impl std::fmt::Display for Errors {
}

/// Builder-style helpers for attaching extra information (span, hints, code)
/// to an error value. Every method except [`WithErrorInfo::span`] consumes
/// `self` and returns the updated value, so calls can be chained.
pub trait WithErrorInfo: Sized {
    /// Returns the span currently associated with this value, if any.
    fn span(&self) -> Option<&Span>;

    /// Returns `self` with `hint` applied by the implementor.
    fn push_hint<S: Into<String>>(self, hint: S) -> Self;

    /// Returns `self` with `hints` applied by the implementor.
    fn with_hints<S: Into<String>, I: IntoIterator<Item = S>>(self, hints: I) -> Self;

    /// Returns `self` with `span` applied by the implementor.
    fn with_span(self, span: Option<Span>) -> Self;

    /// Returns `self` with `code` applied by the implementor.
    fn with_code(self, code: &'static str) -> Self;

    /// Applies `span` through [`WithErrorInfo::with_span`] only when
    /// [`WithErrorInfo::span`] reports that no span is present; a value that
    /// already carries a span is returned untouched.
    fn with_span_if_not_exists(self, span: Option<Span>) -> Self {
        // An existing span wins over the one offered by the caller.
        if self.span().is_some() {
            return self;
        }
        self.with_span(span)
    }
}

impl WithErrorInfo for Error {
/// Returns a reference to the span stored in this error's `span` field,
/// or `None` when the field is empty.
fn span(&self) -> Option<&Span> {
self.span.as_ref()
}

fn with_hints<S: Into<String>, I: IntoIterator<Item = S>>(mut self, hints: I) -> Self {
self.hints = hints.into_iter().map(|x| x.into()).collect();
self
Expand All @@ -144,6 +158,10 @@ impl WithErrorInfo for Error {

#[cfg(feature = "anyhow")]
impl WithErrorInfo for anyhow::Error {
fn span(&self) -> Option<&Span> {
self.downcast_ref::<Error>().and_then(|e| e.span.as_ref())
}

fn push_hint<S: Into<String>>(self, hint: S) -> Self {
self.downcast_ref::<Error>()
.map(|e| e.clone().push_hint(hint).into())
Expand Down Expand Up @@ -173,6 +191,10 @@ impl WithErrorInfo for anyhow::Error {
}

impl<T, E: WithErrorInfo> WithErrorInfo for Result<T, E> {
fn span(&self) -> Option<&Span> {
self.as_ref().err().and_then(|x| x.span())
}

/// Forwards `hints` to the contained error; an `Ok` value passes through
/// unchanged.
fn with_hints<S: Into<String>, I: IntoIterator<Item = S>>(self, hints: I) -> Self {
    match self {
        Ok(value) => Ok(value),
        Err(err) => Err(err.with_hints(hints)),
    }
}
Expand Down
14 changes: 11 additions & 3 deletions prqlc/prqlc-ast/src/expr/ident.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,18 @@ impl Ident {
.all(|(prefix_component, self_component)| prefix_component == self_component)
}

/// Checks whether the leading components of this ident equal `prefix`,
/// compared component by component.
///
/// A `prefix` longer than `self.path.len() + 1` can never match and yields
/// `false`. An empty `prefix` always yields `true`.
pub fn starts_with_path<S: AsRef<str>>(&self, prefix: &[S]) -> bool {
    // Upper bound on how many components a matching prefix may have
    // (presumably the path components plus the final name).
    let max_len = self.path.len() + 1;

    prefix.len() <= max_len
        && prefix
            .iter()
            .zip(self.iter())
            .all(|(expected, actual)| expected.as_ref() == actual)
}

pub fn starts_with_part(&self, prefix: &str) -> bool {
self.iter()
.next()
.map_or(false, |self_component| self_component == prefix)
self.starts_with_path(&[prefix])
}
}

Expand Down
9 changes: 5 additions & 4 deletions prqlc/prqlc/src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use clap::{CommandFactory, Parser, Subcommand, ValueHint};
use clio::has_extension;
use clio::Output;
use itertools::Itertools;
use prqlc::semantic::NS_DEFAULT_DB;
use prqlc_ast::stmt::StmtKind;
use std::collections::HashMap;
use std::env;
Expand Down Expand Up @@ -376,7 +377,7 @@ impl Command {
semantic::load_std_lib(sources);

let ast = prql_to_pl_tree(sources)?;
let ir = pl_to_rq_tree(ast, &main_path)?;
let ir = pl_to_rq_tree(ast, &main_path, &[NS_DEFAULT_DB.to_string()])?;

match format {
Format::Json => serde_json::to_string_pretty(&ir)?.into_bytes(),
Expand All @@ -397,7 +398,7 @@ impl Command {
.with_format(*format);

prql_to_pl_tree(sources)
.and_then(|pl| pl_to_rq_tree(pl, &main_path))
.and_then(|pl| pl_to_rq_tree(pl, &main_path, &[NS_DEFAULT_DB.to_string()]))
.and_then(|rq| rq_to_sql(rq, &opts))
.map_err(|e| e.composed(sources))?
.as_bytes()
Expand All @@ -408,15 +409,15 @@ impl Command {
semantic::load_std_lib(sources);

let ast = prql_to_pl_tree(sources)?;
let rq = pl_to_rq_tree(ast, &main_path)?;
let rq = pl_to_rq_tree(ast, &main_path, &[NS_DEFAULT_DB.to_string()])?;
let srq = prqlc::sql::internal::preprocess(rq)?;
format!("{srq:#?}").as_bytes().to_vec()
}
Command::SQLAnchor { format, .. } => {
semantic::load_std_lib(sources);

let ast = prql_to_pl_tree(sources)?;
let rq = pl_to_rq_tree(ast, &main_path)?;
let rq = pl_to_rq_tree(ast, &main_path, &[NS_DEFAULT_DB.to_string()])?;
let srq = prqlc::sql::internal::anchor(rq)?;

let json = serde_json::to_string_pretty(&srq)?;
Expand Down
10 changes: 7 additions & 3 deletions prqlc/prqlc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ pub fn compile(prql: &str, options: &Options) -> Result<String, ErrorMessages> {
semantic::load_std_lib(&mut sources);

parser::parse(&sources)
.and_then(|ast| semantic::resolve_and_lower(ast, &[]))
.and_then(|ast| semantic::resolve_and_lower(ast, &[], None))
.and_then(|rq| sql::compile(rq, options))
.map_err(error_message::downcast)
.map_err(|e| e.composed(&prql.into()))
Expand Down Expand Up @@ -269,6 +269,7 @@ impl Options {
pub struct ReadmeDoctests;

/// Parse PRQL into a PL AST
// TODO: rename this to `prql_to_pl_simple`
pub fn prql_to_pl(prql: &str) -> Result<Vec<prqlc_ast::stmt::Stmt>, ErrorMessages> {
let sources = SourceTree::from(prql);

Expand All @@ -288,17 +289,20 @@ pub fn prql_to_pl_tree(
}

/// Perform semantic analysis and convert PL to RQ.
// TODO: rename this to `pl_to_rq_simple`
pub fn pl_to_rq(pl: Vec<prqlc_ast::stmt::Stmt>) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
let source_tree = SourceTree::single(PathBuf::new(), pl);
semantic::resolve_and_lower(source_tree, &[]).map_err(error_message::downcast)
semantic::resolve_and_lower(source_tree, &[], None).map_err(error_message::downcast)
}

/// Perform semantic analysis and convert PL to RQ.
pub fn pl_to_rq_tree(
pl: SourceTree<Vec<prqlc_ast::stmt::Stmt>>,
main_path: &[String],
database_module_path: &[String],
) -> Result<ir::rq::RelationalQuery, ErrorMessages> {
semantic::resolve_and_lower(pl, main_path).map_err(error_message::downcast)
semantic::resolve_and_lower(pl, main_path, Some(database_module_path))
.map_err(error_message::downcast)
}

/// Generate SQL from RQ.
Expand Down
69 changes: 48 additions & 21 deletions prqlc/prqlc/src/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@ use crate::COMPILER_VERSION;
use crate::{Error, Reason, Span, WithErrorInfo};
use prqlc_ast::expr::generic::{InterpolateItem, Range, SwitchCase};

use super::NS_DEFAULT_DB;

/// Convert AST into IR and make sure that:
/// Convert a resolved expression at path `main_path` relative to `root_mod`
/// into RQ and make sure that:
/// - transforms are not nested,
/// - transforms have correct partition, window and sort set,
/// - make sure there are no unresolved expressions.
///
/// All table references must reside within module at `database_module_path`.
/// They are compiled to table identifiers, using their path relative to the database module.
/// For example, with `database_module_path=my_database`:
/// - `my_database.my_table` will compile to `"my_table"`,
/// - `my_database.my_schema.my_table` will compile to `"my_schema.my_table"`,
/// - `my_table` will error out saying that this table does not reside in current database.
pub fn lower_to_ir(
root_mod: RootModule,
main_path: &[String],
database_module_path: &[String],
) -> Result<(RelationalQuery, RootModule)> {
// find main
log::debug!("lookup for main pipeline in {main_path:?}");
Expand All @@ -50,12 +57,16 @@ pub fn lower_to_ir(
let tables = toposort_tables(tables, &main_ident);

// lower tables
let mut l = Lowerer::new(root_mod);
let mut l = Lowerer::new(root_mod, database_module_path);
let mut main_relation = None;
for (fq_ident, table) in tables {
for (fq_ident, (table, declared_at)) in tables {
let is_main = fq_ident == main_ident;

l.lower_table_decl(table, fq_ident)?;
let span = declared_at
.and_then(|id| l.root_mod.span_map.get(&id))
.cloned();

l.lower_table_decl(table, fq_ident).with_span_if_not_exists(span)?;

if is_main {
let main_table = l.table_buffer.pop().unwrap();
Expand All @@ -74,13 +85,24 @@ pub fn lower_to_ir(
fn extern_ref_to_relation(
mut columns: Vec<TupleField>,
fq_ident: &Ident,
) -> (rq::Relation, Option<String>) {
let extern_name = if fq_ident.starts_with_part(NS_DEFAULT_DB) {
let (_, remainder) = fq_ident.clone().pop_front();
remainder.unwrap()
database_module_path: &[String],
) -> Result<(rq::Relation, Option<String>), Error> {
let extern_name = if fq_ident.starts_with_path(database_module_path) {
let relative_to_database: Vec<&String> =
fq_ident.iter().skip(database_module_path.len()).collect();
if relative_to_database.is_empty() {
None
} else {
Some(Ident::from_path(relative_to_database))
}
} else {
// tables that are not from default_db: use full name
fq_ident.clone()
None
};

let Some(extern_name) = extern_name else {
let database_module = Ident::from_path(database_module_path.to_vec());
return Err(Error::new_simple("this table is not in the current database")
.push_hint(format!("If this is a table in the current database, move its declaration into module {database_module}")));
};

// put wildcards last
Expand All @@ -90,7 +112,7 @@ fn extern_ref_to_relation(
kind: rq::RelationKind::ExternRef(extern_name),
columns: tuple_fields_to_relation_columns(columns),
};
(relation, None)
Ok((relation, None))
}

fn tuple_fields_to_relation_columns(columns: Vec<TupleField>) -> Vec<RelationColumn> {
Expand Down Expand Up @@ -118,6 +140,7 @@ struct Lowerer {
tid: IdGenerator<TId>,

root_mod: RootModule,
database_module_path: Vec<String>,

/// describes what a certain id has been lowered to
node_mapping: HashMap<usize, LoweredTarget>,
Expand Down Expand Up @@ -146,9 +169,10 @@ enum LoweredTarget {
}

impl Lowerer {
fn new(root_mod: RootModule) -> Self {
fn new(root_mod: RootModule, database_module_path: &[String]) -> Self {
Lowerer {
root_mod,
database_module_path: database_module_path.to_vec(),

cid: IdGenerator::new(),
tid: IdGenerator::new(),
Expand All @@ -173,7 +197,9 @@ impl Lowerer {
// a CTE
(self.lower_relation(*expr)?, Some(fq_ident.name.clone()))
}
TableExpr::LocalTable => extern_ref_to_relation(columns, &fq_ident),
TableExpr::LocalTable => {
extern_ref_to_relation(columns, &fq_ident, &self.database_module_path)?
}
TableExpr::Param(_) => unreachable!(),
TableExpr::None => return Ok(()),
};
Expand Down Expand Up @@ -1008,12 +1034,12 @@ fn validate_take_range(range: &Range<rq::Expr>, span: Option<Span>) -> Result<()
struct TableExtractor {
path: Vec<String>,

tables: Vec<(Ident, decl::TableDecl)>,
tables: Vec<(Ident, (decl::TableDecl, Option<usize>))>,
}

impl TableExtractor {
/// Finds table declarations in a module, recursively.
fn extract(root_module: &Module) -> Vec<(Ident, decl::TableDecl)> {
fn extract(root_module: &Module) -> Vec<(Ident, (decl::TableDecl, Option<usize>))> {
let mut te = TableExtractor::default();
te.extract_from_module(root_module);
te.tables
Expand All @@ -1030,7 +1056,8 @@ impl TableExtractor {
}
DeclKind::TableDecl(table) => {
let fq_ident = Ident::from_path(self.path.clone());
self.tables.push((fq_ident, table.clone()));
self.tables
.push((fq_ident, (table.clone(), entry.declared_at)));
}
_ => {}
}
Expand All @@ -1043,14 +1070,14 @@ impl TableExtractor {
/// are not needed for the main pipeline. To do this, it needs to collect references
/// between pipelines.
fn toposort_tables(
tables: Vec<(Ident, decl::TableDecl)>,
tables: Vec<(Ident, (decl::TableDecl, Option<usize>))>,
main_table: &Ident,
) -> Vec<(Ident, decl::TableDecl)> {
) -> Vec<(Ident, (decl::TableDecl, Option<usize>))> {
let tables: HashMap<_, _, RandomState> = HashMap::from_iter(tables);

let mut dependencies: Vec<(Ident, Vec<Ident>)> = Vec::new();
for (ident, table) in &tables {
let deps = if let TableExpr::RelationVar(e) = &table.expr {
let deps = if let TableExpr::RelationVar(e) = &table.0.expr {
TableDepsCollector::collect(*e.clone())
} else {
vec![]
Expand Down
7 changes: 5 additions & 2 deletions prqlc/prqlc/src/semantic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ use crate::{Error, Reason, SourceTree};
pub fn resolve_and_lower(
file_tree: SourceTree<Vec<prqlc_ast::stmt::Stmt>>,
main_path: &[String],
database_module_path: Option<&[String]>,
) -> Result<RelationalQuery> {
let root_mod = resolve(file_tree, Default::default())?;

let (query, _) = lowering::lower_to_ir(root_mod, main_path)?;
let default_db = [NS_DEFAULT_DB.to_string()];
let database_module_path = database_module_path.unwrap_or(&default_db);
let (query, _) = lowering::lower_to_ir(root_mod, main_path, database_module_path)?;
Ok(query)
}

Expand Down Expand Up @@ -265,7 +268,7 @@ pub mod test {

pub fn parse_resolve_and_lower(query: &str) -> Result<RelationalQuery> {
let source_tree = query.into();
resolve_and_lower(parse(&source_tree)?, &[])
resolve_and_lower(parse(&source_tree)?, &[], None)
}

pub fn parse_and_resolve(query: &str) -> Result<RootModule> {
Expand Down
10 changes: 9 additions & 1 deletion prqlc/prqlc/tests/integration/bad_error_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ fn nested_groups() {
)
)
"###).unwrap_err(), @r###"
Error: internal compiler error; tracked at https://github.com/PRQL/prql/issues/3870
Error:
╭─[:2:5]
2 │ ╭─▶ from inv=invoices
┆ ┆
12 │ ├─▶ )
│ │
│ ╰─────────── internal compiler error; tracked at https://github.com/PRQL/prql/issues/3870
────╯
"###);
}
14 changes: 8 additions & 6 deletions prqlc/prqlc/tests/integration/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2273,7 +2273,7 @@ fn test_from_json() {
prqlc::semantic::load_std_lib(&mut source_tree);

let sql_from_prql = Ok(prqlc::prql_to_pl_tree(&source_tree).unwrap())
.and_then(|ast| prqlc::semantic::resolve_and_lower(ast, &[]))
.and_then(|ast| prqlc::semantic::resolve_and_lower(ast, &[], None))
.and_then(|rq| sql::compile(rq, &Options::default()))
.unwrap();

Expand Down Expand Up @@ -4935,13 +4935,15 @@ fn test_group_exclude() {
fn test_table_declarations() {
assert_display_snapshot!(compile(
r###"
module my_schema {
let my_table <[{ id = int, a = text }]>
}
module default_db {
module my_schema {
let my_table <[{ id = int, a = text }]>
}
let another_table <[{ id = int, b = text }]>
let another_table <[{ id = int, b = text }]>
}
my_schema.my_table | join another_table (==id) | take 10
from my_schema.my_table | join another_table (==id) | take 10
"###,
)
.unwrap(), @r###"
Expand Down

0 comments on commit 1018393

Please sign in to comment.