Skip to content

Commit

Permalink
feat: infer types of transforms (#3845)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Nov 25, 2023
1 parent b09949f commit 1759a86
Show file tree
Hide file tree
Showing 17 changed files with 764 additions and 312 deletions.
4 changes: 2 additions & 2 deletions prqlc/prql-compiler/src/ir/decl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub enum DeclKind {

TableDecl(TableDecl),

InstanceOf(Ident),
InstanceOf(Ident, Option<Ty>),

/// A single column. Contains id of target which is either:
/// - an input relation that is source of this column or
Expand Down Expand Up @@ -190,7 +190,7 @@ impl std::fmt::Display for DeclKind {
ty.as_ref().map(write_ty).unwrap_or_default()
)
}
Self::InstanceOf(arg0) => write!(f, "InstanceOf: {arg0}"),
Self::InstanceOf(arg0, _) => write!(f, "InstanceOf: {arg0}"),
Self::Column(arg0) => write!(f, "Column (target {arg0})"),
Self::Infer(arg0) => write!(f, "Infer (default: {arg0})"),
Self::Expr(arg0) => write!(f, "Expr: {}", write_pl(*arg0.clone())),
Expand Down
4 changes: 2 additions & 2 deletions prqlc/prql-compiler/src/ir/pl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ pub struct Expr {
pub enum ExprKind {
Ident(Ident),
All {
within: Ident,
except: Vec<Expr>,
within: Box<Expr>,
except: Box<Expr>,
},
Literal(Literal),

Expand Down
4 changes: 2 additions & 2 deletions prqlc/prql-compiler/src/ir/pl/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ pub fn fold_expr_kind<T: ?Sized + PlFold>(fold: &mut T, expr_kind: ExprKind) ->
Ok(match expr_kind {
Ident(ident) => Ident(ident),
All { within, except } => All {
within,
except: fold.fold_exprs(except)?,
within: Box::new(fold.fold_expr(*within)?),
except: Box::new(fold.fold_expr(*except)?),
},
Tuple(items) => Tuple(fold.fold_exprs(items)?),
Array(items) => Array(fold.fold_exprs(items)?),
Expand Down
4 changes: 2 additions & 2 deletions prqlc/prql-compiler/src/semantic/ast_expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ fn restrict_expr_kind(value: pl::ExprKind) -> ExprKind {
pl::ExprKind::Internal(v) => ExprKind::Internal(v),

// TODO: these are not correct, they are producing invalid PRQL
pl::ExprKind::All { within, .. } => ExprKind::Ident(within),
pl::ExprKind::All { within, .. } => restrict_expr(*within).kind,
pl::ExprKind::TransformCall(tc) => ExprKind::Ident(Ident::from_name(format!(
"({} ...)",
tc.kind.as_ref().as_ref()
Expand Down Expand Up @@ -459,7 +459,7 @@ fn restrict_decl(name: String, value: decl::Decl) -> Option<Stmt> {
ty: table_decl.ty,
}),

decl::DeclKind::InstanceOf(ident) => {
decl::DeclKind::InstanceOf(ident, _) => {
new_internal_stmt(name, format!("instance_of.{ident}"))
}
decl::DeclKind::Column(id) => new_internal_stmt(name, format!("column.{id}")),
Expand Down
132 changes: 74 additions & 58 deletions prqlc/prql-compiler/src/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,33 +643,9 @@ impl Lowerer {
let mut r = Vec::new();

match exprs.kind {
pl::ExprKind::All { except, .. } => {
pl::ExprKind::All { .. } => {
// special case: ExprKind::All
let mut selected = Vec::<CId>::new();
for target_id in exprs.target_ids {
match &self.node_mapping[&target_id] {
LoweredTarget::Compute(cid) => {
selected.push(*cid);
}
LoweredTarget::Input(input) => {
let mut cols = input.iter().collect_vec();
cols.sort_by_key(|c| c.1 .1);
selected.extend(cols.into_iter().map(|(_, (cid, _))| cid));
}
}
}

let except: HashSet<CId> = except
.into_iter()
.filter(|e| e.target_id.is_some())
.map(|e| {
let id = e.target_id.unwrap();
self.lookup_cid(id, Some(&e.kind.into_ident().unwrap().name))
})
.try_collect()?;
selected.retain(|c| !except.contains(c));

r.extend(selected);
r.extend(self.find_selected_all(exprs)?);
}
pl::ExprKind::Tuple(fields) => {
// tuple unpacking
Expand All @@ -685,6 +661,61 @@ impl Lowerer {
Ok(r)
}

fn find_selected_all(&mut self, expr: pl::Expr) -> Result<Vec<CId>> {
let pl::ExprKind::All { except, .. } = expr.kind else {
unreachable!()
};

let mut selected = Vec::<CId>::new();
for target_id in &expr.target_ids {
match self.node_mapping.get(target_id) {
Some(LoweredTarget::Compute(cid)) => selected.push(*cid),
Some(LoweredTarget::Input(input_columns)) => {
let mut cols = input_columns.iter().collect_vec();
cols.sort_by_key(|c| c.1 .1);
selected.extend(cols.into_iter().map(|(_, (cid, _))| cid));
}
_ => {}
}
}
let except: HashSet<_> = self.find_except_ids(*except)?;
selected.retain(|t| !except.contains(t));
Ok(selected)
}

/// Collects the column ids named by the `except` part of a wildcard.
///
/// `except` is expected to be a tuple; any other expression kind yields an
/// empty set. For each tuple field:
/// - fields without a `target_id` are skipped,
/// - identifiers are resolved through `lookup_cid` (a failure is reported
///   with the span of the whole `except` expression),
/// - nested `All` expressions are expanded through `find_selected_all`,
/// - anything else is rejected with an "expected an identifier" error that
///   carries the span of the offending field.
fn find_except_ids(&mut self, except: pl::Expr) -> Result<HashSet<CId>> {
    // Copy the span out before `except.kind` is moved by the pattern below.
    let except_span = except.span;
    let pl::ExprKind::Tuple(fields) = except.kind else {
        return Ok(HashSet::new());
    };

    let mut res = HashSet::new();
    for e in fields {
        let Some(id) = e.target_id else {
            continue;
        };
        // Remember where this field came from; `e` is consumed below.
        let field_span = e.span;

        match e.kind {
            pl::ExprKind::Ident(ident) => {
                res.insert(
                    self.lookup_cid(id, Some(&ident.name))
                        .with_span(except_span)?,
                );
            }
            pl::ExprKind::All { .. } => res.extend(self.find_selected_all(e)?),
            _ => {
                // Attach the field's span so the diagnostic points at the
                // offending expression, like the other error paths here.
                return Err(Error::new(Reason::Expected {
                    who: None,
                    expected: "an identifier".to_string(),
                    found: write_pl(e),
                })
                .with_span(field_span)
                .into());
            }
        }
    }
    Ok(res)
}

fn declare_as_column(
&mut self,
mut expr_ast: pl::Expr,
Expand Down Expand Up @@ -741,6 +772,8 @@ impl Lowerer {
}

fn lower_expr(&mut self, expr: pl::Expr) -> Result<rq::Expr> {
let span = expr.span;

if expr.needs_window {
let span = expr.span;
let cid = self.declare_as_column(expr, false)?;
Expand All @@ -754,7 +787,7 @@ impl Lowerer {
log::debug!("lowering ident {ident} (target {:?})", expr.target_id);

if let Some(id) = expr.target_id {
let cid = self.lookup_cid(id, Some(&ident.name))?;
let cid = self.lookup_cid(id, Some(&ident.name)).with_span(span)?;

rq::ExprKind::ColumnRef(cid)
} else {
Expand All @@ -763,37 +796,15 @@ impl Lowerer {
rq::ExprKind::SString(vec![InterpolateItem::String(ident.name)])
}
}
pl::ExprKind::All { except, .. } => {
let mut targets = Vec::new();

for target_id in &expr.target_ids {
match self.node_mapping.get(target_id) {
Some(LoweredTarget::Compute(cid)) => targets.push(*cid),
Some(LoweredTarget::Input(input_columns)) => {
targets.extend(input_columns.values().map(|(c, _)| c))
}
_ => {}
}
}

// this is terrible code
let except: HashSet<_> = except
.iter()
.map(|e| {
let ident = e.kind.as_ident().unwrap();
self.lookup_cid(e.target_id.unwrap(), Some(&ident.name))
.unwrap()
})
.collect();
pl::ExprKind::All { .. } => {
let selected = self.find_selected_all(expr)?;

targets.retain(|t| !except.contains(t));

if targets.len() == 1 {
rq::ExprKind::ColumnRef(targets[0])
if selected.len() == 1 {
rq::ExprKind::ColumnRef(selected[0])
} else {
return Err(
Error::new_simple("This wildcard usage is not yet supported.")
.with_span(expr.span)
.with_span(span)
.into(),
);
}
Expand Down Expand Up @@ -853,10 +864,7 @@ impl Lowerer {
}
};

Ok(rq::Expr {
kind,
span: expr.span,
})
Ok(rq::Expr { kind, span })
}

fn lower_interpolations(
Expand All @@ -881,6 +889,14 @@ impl Lowerer {
let cid = match self.node_mapping.get(&id) {
Some(LoweredTarget::Compute(cid)) => *cid,
Some(LoweredTarget::Input(input_columns)) => {
if name.map_or(false, |x| x == "_self") {
return Err(
Error::new_simple("table instance cannot be referenced directly")
.push_hint("did you forget to specify the column name?")
.into(),
);
}

let name = match name {
Some(v) => RelationColumn::Single(Some(v.clone())),
None => return Err(Error::new_simple(
Expand Down
50 changes: 43 additions & 7 deletions prqlc/prql-compiler/src/semantic/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};

use anyhow::Result;
use prqlc_ast::stmt::QueryDef;
use prqlc_ast::{Span, TupleField, Ty};
use prqlc_ast::{Literal, Span, TupleField, Ty, TyKind};

use crate::ir::pl::{Annotation, Expr, Ident, Lineage, LineageColumn};
use crate::Error;
Expand Down Expand Up @@ -196,14 +196,18 @@ impl Module {
res
}

pub(super) fn insert_frame(&mut self, frame: &Lineage, namespace: &str) {
pub(super) fn insert_frame(&mut self, lineage: &Lineage, namespace: &str) {
let namespace = self.names.entry(namespace.to_string()).or_default();
let namespace = namespace.kind.as_module_mut().unwrap();

for (col_index, column) in frame.columns.iter().enumerate() {
let lin_ty = *ty_of_lineage(lineage).kind.into_array().unwrap();

for (col_index, column) in lineage.columns.iter().enumerate() {
// determine input name
let input_name = match column {
LineageColumn::All { input_id, .. } => frame.find_input(*input_id).map(|i| &i.name),
LineageColumn::All { input_id, .. } => {
lineage.find_input(*input_id).map(|i| &i.name)
}
LineageColumn::Single { name, .. } => name.as_ref().and_then(|n| n.path.first()),
};

Expand All @@ -215,12 +219,22 @@ impl Module {
None => {
namespace.redirects.push(Ident::from_name(input_name));

let input = frame.find_input_by_name(input_name).unwrap();
let input = lineage.find_input_by_name(input_name).unwrap();
let mut sub_ns = Module::default();

let self_ty = lin_ty.clone().kind.into_tuple().unwrap();
let self_ty = self_ty
.into_iter()
.flat_map(|x| x.into_single())
.find(|(name, _)| name.as_ref() == Some(input_name))
.and_then(|(_, ty)| ty)
.or(Some(Ty::new(TyKind::Tuple(vec![TupleField::Wildcard(
None,
)]))));

let self_decl = Decl {
declared_at: Some(input.id),
kind: DeclKind::InstanceOf(input.table.clone()),
kind: DeclKind::InstanceOf(input.table.clone(), self_ty),
..Default::default()
};
sub_ns.names.insert(NS_SELF.to_string(), self_decl);
Expand All @@ -242,7 +256,7 @@ impl Module {
// insert column decl
match column {
LineageColumn::All { input_id, .. } => {
let input = frame.find_input(*input_id).unwrap();
let input = lineage.find_input(*input_id).unwrap();

let kind = DeclKind::Infer(Box::new(DeclKind::Column(input.id)));
let declared_at = Some(input.id);
Expand Down Expand Up @@ -270,6 +284,12 @@ impl Module {
_ => {}
}
}

// insert namespace._self with correct type
namespace.names.insert(
NS_SELF.to_string(),
Decl::from(DeclKind::InstanceOf(Ident::from_name(""), Some(lin_ty))),
);
}

pub(super) fn insert_frame_col(&mut self, namespace: &str, name: String, id: usize) {
Expand Down Expand Up @@ -466,6 +486,22 @@ impl RootModule {
}
}

/// Builds a relation type from the columns of `lineage`.
///
/// Columns are mapped in order:
/// - `LineageColumn::All` becomes `TupleField::Wildcard(None)`,
/// - `LineageColumn::Single` becomes a `TupleField::Single` holding the
///   column's name (if it has one) and a type built from `Literal::Null`.
pub fn ty_of_lineage(lineage: &Lineage) -> Ty {
    let mut fields = Vec::with_capacity(lineage.columns.len());

    for col in &lineage.columns {
        let field = match col {
            LineageColumn::Single { name, .. } => {
                let field_name = name.as_ref().map(|ident| ident.name.clone());
                TupleField::Single(field_name, Some(Ty::new(Literal::Null)))
            }
            LineageColumn::All { .. } => TupleField::Wildcard(None),
        };
        fields.push(field);
    }

    Ty::relation(fields)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion prqlc/prql-compiler/src/semantic/reporting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl<'a> PlFold for Labeler<'a> {
DeclKind::Expr(_) => Color::Blue,
DeclKind::Ty(_) => Color::Green,
DeclKind::Column { .. } => Color::Yellow,
DeclKind::InstanceOf(_) => Color::Yellow,
DeclKind::InstanceOf(_, _) => Color::Yellow,
DeclKind::TableDecl { .. } => Color::Red,
DeclKind::Module(module) => {
self.label_module(module);
Expand Down
Loading

0 comments on commit 1759a86

Please sign in to comment.