Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support duplicate column names in Substrait consumer #11048

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 137 additions & 33 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ use url::Url;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
Values,
aggregate_function, builder::subquery_alias, expr::find_df_window_func, Aggregate,
BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator,
Projection, ScalarUDF, Values,
};

use datafusion::logical_expr::{
Expand Down Expand Up @@ -233,10 +233,10 @@ pub async fn from_substrait_plan(
match plan.relations[0].rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => {
Ok(from_substrait_rel(ctx, rel, &function_extension).await?)
Ok(from_substrait_rel(ctx, rel, &function_extension, &mut Vec::new()).await?)
},
plan_rel::RelType::Root(root) => {
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?;
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension, &mut Vec::new()).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
Expand Down Expand Up @@ -413,12 +413,13 @@ pub async fn from_substrait_rel(
ctx: &SessionContext,
rel: &Rel,
extensions: &HashMap<u32, &String>,
seen_table_names: &mut Vec<TableReference>,
) -> Result<LogicalPlan> {
match &rel.rel_type {
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
let mut input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
Expand Down Expand Up @@ -446,7 +447,7 @@ pub async fn from_substrait_rel(
Some(RelType::Filter(filter)) => {
if let Some(input) = filter.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);
if let Some(condition) = filter.condition.as_ref() {
let expr =
Expand All @@ -463,7 +464,7 @@ pub async fn from_substrait_rel(
Some(RelType::Fetch(fetch)) => {
if let Some(input) = fetch.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);
let offset = fetch.offset as usize;
// Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX`
Expand All @@ -480,7 +481,7 @@ pub async fn from_substrait_rel(
Some(RelType::Sort(sort)) => {
if let Some(input) = sort.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);
let sorts =
from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions)
Expand All @@ -493,7 +494,7 @@ pub async fn from_substrait_rel(
Some(RelType::Aggregate(agg)) => {
if let Some(input) = agg.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);
let mut group_expr = vec![];
let mut aggr_expr = vec![];
Expand Down Expand Up @@ -589,10 +590,22 @@ pub async fn from_substrait_rel(
}

let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?,
from_substrait_rel(
ctx,
join.left.as_ref().unwrap(),
extensions,
seen_table_names,
)
.await?,
);
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?,
from_substrait_rel(
ctx,
join.right.as_ref().unwrap(),
extensions,
seen_table_names,
)
.await?,
);
let join_type = from_substrait_jointype(join.r#type)?;
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
Expand Down Expand Up @@ -628,11 +641,21 @@ pub async fn from_substrait_rel(
}
Some(RelType::Cross(cross)) => {
let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?,
from_substrait_rel(
ctx,
cross.left.as_ref().unwrap(),
extensions,
seen_table_names,
)
.await?,
);
let right =
from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions)
.await?;
let right = from_substrait_rel(
ctx,
cross.right.as_ref().unwrap(),
extensions,
seen_table_names,
)
.await?;
left.cross_join(right)?.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Expand All @@ -656,7 +679,8 @@ pub async fn from_substrait_rel(
};
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
let t = extract_projection(t, &read.projection)?;
qualify_table_uniquely(t, seen_table_names)
}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
Expand All @@ -666,10 +690,13 @@ pub async fn from_substrait_rel(
let schema = from_substrait_named_struct(base_schema)?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}));
return qualify_table_uniquely(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}),
seen_table_names,
);
}

let values = vt
Expand Down Expand Up @@ -699,8 +726,8 @@ pub async fn from_substrait_rel(
Ok(lits)
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
let t = LogicalPlan::Values(Values { schema, values });
qualify_table_uniquely(t, seen_table_names)
}
Some(ReadType::LocalFiles(lf)) => {
fn extract_filename(name: &str) -> Option<String> {
Expand Down Expand Up @@ -735,7 +762,8 @@ pub async fn from_substrait_rel(
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
let t = extract_projection(t, &read.projection)?;
qualify_table_uniquely(t, seen_table_names)
}
_ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type),
},
Expand All @@ -744,11 +772,24 @@ pub async fn from_substrait_rel(
set_rel::SetOp::UnionAll => {
if !set.inputs.is_empty() {
let mut union_builder = Ok(LogicalPlanBuilder::from(
from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
from_substrait_rel(
ctx,
&set.inputs[0],
extensions,
seen_table_names,
)
.await?,
));
for input in &set.inputs[1..] {
union_builder = union_builder?
.union(from_substrait_rel(ctx, input, extensions).await?);
union_builder = union_builder?.union(
from_substrait_rel(
ctx,
input,
extensions,
seen_table_names,
)
.await?,
);
}
union_builder?.build()
} else {
Expand Down Expand Up @@ -782,7 +823,8 @@ pub async fn from_substrait_rel(
"ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead"
);
};
let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?;
let input_plan =
from_substrait_rel(ctx, input_rel, extensions, seen_table_names).await?;
let plan =
plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
Expand All @@ -797,7 +839,8 @@ pub async fn from_substrait_rel(
.deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
let mut inputs = Vec::with_capacity(extension.inputs.len());
for input in &extension.inputs {
let input_plan = from_substrait_rel(ctx, input, extensions).await?;
let input_plan =
from_substrait_rel(ctx, input, extensions, seen_table_names).await?;
inputs.push(input_plan);
}
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
Expand All @@ -807,7 +850,9 @@ pub async fn from_substrait_rel(
let Some(input) = exchange.input.as_ref() else {
return substrait_err!("Unexpected empty input in ExchangeRel");
};
let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?);
let input = Arc::new(
from_substrait_rel(ctx, input, extensions, seen_table_names).await?,
);

let Some(exchange_kind) = &exchange.exchange_kind else {
return substrait_err!("Unexpected empty input in ExchangeRel");
Expand Down Expand Up @@ -846,6 +891,61 @@ pub async fn from_substrait_rel(
}
}

// Substrait doesn't keep table aliases, nor column aliases within the plan - columns are just
// referred to by their index in the schema. This means that column names within the plan may not
// be unique, unless we ensure each input table is qualified with a unique table reference.
fn qualify_table_uniquely(
table: LogicalPlan,
seen_table_names: &mut Vec<TableReference>,
) -> Result<LogicalPlan> {
let original_table_ref = match table.schema().columns().first() {
Some(col) => col.relation.to_owned(),
None => return Ok(table),
};

let original_table_name = original_table_ref
.to_owned()
.map(|r| r.table().to_owned())
.unwrap_or("table".to_string());

let mut new_ref = original_table_ref
.to_owned()
.unwrap_or(TableReference::bare("table"));
let mut i = 1;
while seen_table_names.contains(&new_ref) {
match new_ref {
TableReference::Bare { table: _ } => {
new_ref = TableReference::bare(format!("{}_{}", original_table_name, i));
}
TableReference::Partial { schema, table: _ } => {
new_ref = TableReference::partial(
schema.to_owned(),
format!("{}_{}", original_table_name, i),
);
}
TableReference::Full {
catalog,
schema,
table: _,
} => {
new_ref = TableReference::full(
catalog.to_owned(),
schema.to_owned(),
format!("{}_{}", original_table_name, i),
);
}
}
i += 1;
}

seen_table_names.push(new_ref.to_owned());

if original_table_ref.as_ref() == Some(&new_ref) {
return Ok(table);
}
subquery_alias(table, new_ref)
}

fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) {
match substrait_join_type {
Expand Down Expand Up @@ -1267,9 +1367,13 @@ pub async fn from_substrait_rex(
let needle_expr = &in_predicate.needles[0];
let haystack_expr = &in_predicate.haystack;
if let Some(haystack_expr) = haystack_expr {
let haystack_expr =
from_substrait_rel(ctx, haystack_expr, extensions)
.await?;
let haystack_expr = from_substrait_rel(
ctx,
haystack_expr,
extensions,
&mut Vec::new(),
)
.await?;
let outer_refs = haystack_expr.all_out_ref_exprs();
Ok(Arc::new(Expr::InSubquery(InSubquery {
expr: Box::new(
Expand Down
Loading
Loading