feat(optimizer): ensure column name unique when create MV (risingwavelabs#1363)

* add result on gen_create_mv_plan

* add result on gen_create_mv_plan

* refactor row id

* refactor row id

* Materialize derive schema

* refactor gen_row_id_column_name

* project schema(name) derive

* fix ut

* add alias_check if is row id

* add column name check if is row id

* check alias not prefixed with _row_id

* clippy fix

* ut
st1page authored Mar 29, 2022
1 parent 4866021 commit 9d731c2
Showing 20 changed files with 217 additions and 140 deletions.
9 changes: 9 additions & 0 deletions rust/frontend/src/binder/select.rs
@@ -22,6 +22,7 @@ use risingwave_sqlparser::ast::{Expr, Select, SelectItem};
use super::bind_context::{Clause, ColumnBinding};
use super::UNNAMED_COLUMN;
use crate::binder::{Binder, Relation};
use crate::catalog::{is_row_id_column_name, ROWID_PREFIX};
use crate::expr::{Expr as _, ExprImpl, InputRef};

#[derive(Debug)]
@@ -118,6 +119,14 @@ impl Binder {
aliases.push(alias);
}
SelectItem::ExprWithAlias { expr, alias } => {
if is_row_id_column_name(&alias.value) {
return Err(ErrorCode::InternalError(format!(
"column name prefixed with {:?} are reserved word.",
ROWID_PREFIX
))
.into());
}

let expr = self.bind_expr(expr)?;
select_list.push(expr);
aliases.push(Some(alias.value));
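
The guard runs before the aliased expression is bound, so a statement such as "create materialized view mv as select v1 as _row_id from t" is rejected up front. The sketch below is a minimal standalone approximation of that per-item check, with each SELECT item reduced to its optional alias; check_aliases and its String error type are illustrative names, not the binder's actual API.

const ROWID_PREFIX: &str = "_row_id";

// Simplified stand-in for the alias check added to the binder's SELECT-list
// handling: each SELECT item is reduced to its optional alias.
fn check_aliases(aliases: &[Option<&str>]) -> Result<(), String> {
    for alias in aliases.iter().flatten() {
        if alias.starts_with(ROWID_PREFIX) {
            return Err(format!(
                "column names prefixed with {:?} are reserved",
                ROWID_PREFIX
            ));
        }
    }
    Ok(())
}

fn main() {
    // SELECT v1 AS a, v2 FROM t  -- accepted
    assert!(check_aliases(&[Some("a"), None]).is_ok());
    // SELECT v1 AS _row_id FROM t  -- rejected before the expression is bound
    assert!(check_aliases(&[Some("_row_id")]).is_err());
}
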
10 changes: 10 additions & 0 deletions rust/frontend/src/catalog/mod.rs
@@ -30,6 +30,16 @@ pub(crate) type SchemaId = u32;
pub(crate) type TableId = risingwave_common::catalog::TableId;
pub(crate) type ColumnId = risingwave_common::catalog::ColumnId;

pub const ROWID_PREFIX: &str = "_row_id";

pub fn gen_row_id_column_name(idx: usize) -> String {
ROWID_PREFIX.to_string() + "#" + &idx.to_string()
}

pub fn is_row_id_column_name(name: &str) -> bool {
name.starts_with(ROWID_PREFIX)
}

#[derive(Error, Debug)]
pub enum CatalogError {
#[error("{0} not found: {1}")]
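
To illustrate the two new helpers, the sketch below reproduces them outside the crate::catalog module and shows the names they generate and match. Note that is_row_id_column_name is a plain prefix check, so any name starting with _row_id counts as reserved, not only the generated _row_id#<idx> form.

// Standalone sketch of the helpers added in rust/frontend/src/catalog/mod.rs.
const ROWID_PREFIX: &str = "_row_id";

fn gen_row_id_column_name(idx: usize) -> String {
    ROWID_PREFIX.to_string() + "#" + &idx.to_string()
}

fn is_row_id_column_name(name: &str) -> bool {
    name.starts_with(ROWID_PREFIX)
}

fn main() {
    // Generated names encode the column index after a '#' separator.
    assert_eq!(gen_row_id_column_name(0), "_row_id#0");
    assert_eq!(gen_row_id_column_name(3), "_row_id#3");

    // The check is prefix-based: generated names and any user-supplied name
    // that merely starts with the prefix are both treated as reserved.
    assert!(is_row_id_column_name("_row_id#0"));
    assert!(is_row_id_column_name("_row_id_extra"));
    assert!(!is_row_id_column_name("row_id"));
}
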
8 changes: 4 additions & 4 deletions rust/frontend/src/catalog/table_catalog.rs
@@ -159,8 +159,8 @@ mod tests {
use risingwave_pb::plan::{ColumnCatalog as ProstColumnCatalog, ColumnDesc as ProstColumnDesc};

use crate::catalog::column_catalog::ColumnCatalog;
use crate::catalog::gen_row_id_column_name;
use crate::catalog::table_catalog::TableCatalog;
use crate::handler::create_table::ROWID_NAME;

#[test]
fn test_into_table_catalog() {
@@ -173,7 +173,7 @@
ProstColumnCatalog {
column_desc: Some(ProstColumnDesc {
column_id: 0,
name: ROWID_NAME.to_string(),
name: gen_row_id_column_name(0),
field_descs: vec![],
column_type: Some(DataType::Int32.to_protobuf()),
type_name: String::new(),
@@ -220,7 +220,7 @@
column_desc: ColumnDesc {
data_type: DataType::Int32,
column_id: ColumnId::new(0),
name: ROWID_NAME.to_string(),
name: gen_row_id_column_name(0),
field_descs: vec![],
type_name: String::new()
},
@@ -258,7 +258,7 @@
column_desc: ColumnDesc {
data_type: DataType::Int32,
column_id: ColumnId::new(0),
name: ROWID_NAME.to_string(),
name: gen_row_id_column_name(0),
field_descs: vec![],
type_name: String::new()
},
2 changes: 1 addition & 1 deletion rust/frontend/src/handler/create_mv.rs
@@ -47,7 +47,7 @@ pub fn gen_create_mv_plan(

let mut plan_root = Planner::new(context).plan_query(bound)?;
plan_root.set_required_dist(Distribution::any().clone());
let materialize = plan_root.gen_create_mv_plan(table_name);
let materialize = plan_root.gen_create_mv_plan(table_name)?;
let table = materialize.table().to_prost(schema_id, database_id);
let plan: PlanRef = materialize.into();

8 changes: 5 additions & 3 deletions rust/frontend/src/handler/create_source.rs
@@ -26,7 +26,7 @@ use risingwave_source::ProtobufParser;
use risingwave_sqlparser::ast::{CreateSourceStatement, ProtobufSchema, SourceSchema};

use crate::binder::expr::bind_data_type;
use crate::handler::create_table::ROWID_NAME;
use crate::catalog::gen_row_id_column_name;
use crate::session::OptimizerContext;

fn extract_protobuf_table_schema(schema: &ProtobufSchema) -> Result<Vec<ColumnCatalog>> {
@@ -55,7 +55,7 @@ pub(super) async fn handle_create_source(
let mut column_catalogs = vec![ColumnCatalog {
column_desc: Some(ColumnDesc {
column_id: 0,
name: ROWID_NAME.to_string(),
name: gen_row_id_column_name(0),
column_type: Some(DataType::Int32.to_protobuf()),
field_descs: vec![],
type_name: "".to_string(),
@@ -142,6 +142,7 @@ pub mod tests {

use super::*;
use crate::catalog::column_catalog::ColumnCatalog;
use crate::catalog::gen_row_id_column_name;
use crate::test_utils::LocalFrontend;

/// Returns the file.
@@ -220,8 +221,9 @@
let city_type = DataType::Struct {
fields: vec![DataType::Varchar, DataType::Varchar].into(),
};
let row_id_col_name = gen_row_id_column_name(0);
let expected_columns = maplit::hashmap! {
ROWID_NAME => DataType::Int32,
row_id_col_name.as_str() => DataType::Int32,
"id" => DataType::Int32,
"country.zipcode" => DataType::Varchar,
"zipcode" => DataType::Int64,
22 changes: 15 additions & 7 deletions rust/frontend/src/handler/create_table.rs
@@ -16,7 +16,7 @@ use fixedbitset::FixedBitSet;
use itertools::Itertools;
use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId};
use risingwave_common::error::Result;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_pb::catalog::source::Info;
use risingwave_pb::catalog::{Source as ProstSource, Table as ProstTable, TableSourceInfo};
@@ -25,13 +25,12 @@ use risingwave_sqlparser::ast::{ColumnDef, ObjectName};

use crate::binder::expr::bind_data_type;
use crate::binder::Binder;
use crate::catalog::{gen_row_id_column_name, is_row_id_column_name, ROWID_PREFIX};
use crate::optimizer::plan_node::StreamSource;
use crate::optimizer::property::{Distribution, Order};
use crate::optimizer::{PlanRef, PlanRoot};
use crate::session::{OptimizerContext, OptimizerContextRef, SessionImpl};

pub const ROWID_NAME: &str = "_row_id";

pub fn gen_create_table_plan(
session: &SessionImpl,
context: OptimizerContextRef,
@@ -51,12 +50,20 @@
column_descs.push(ColumnDesc {
data_type: DataType::Int64,
column_id: ColumnId::new(0),
name: ROWID_NAME.to_string(),
name: gen_row_id_column_name(0),
field_descs: vec![],
type_name: "".to_string(),
});
// Then user columns.
for (i, column) in columns.into_iter().enumerate() {
if is_row_id_column_name(&column.name.value) {
return Err(ErrorCode::InternalError(format!(
"column name prefixed with {:?} are reserved word.",
ROWID_PREFIX
))
.into());
}

column_descs.push(ColumnDesc {
data_type: bind_data_type(&column.data_type)?,
column_id: ColumnId::new((i + 1) as i32),
@@ -106,7 +113,7 @@
Order::any().clone(),
required_cols,
)
.gen_create_mv_plan(table_name)
.gen_create_mv_plan(table_name)?
};
let table = materialize.table().to_prost(schema_id, database_id);

@@ -154,7 +161,7 @@ mod tests {
use risingwave_common::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME};
use risingwave_common::types::DataType;

use super::*;
use crate::catalog::gen_row_id_column_name;
use crate::test_utils::LocalFrontend;

#[tokio::test]
@@ -188,8 +195,9 @@
.map(|col| (col.name(), col.data_type().clone()))
.collect::<HashMap<&str, DataType>>();

let row_id_col_name = gen_row_id_column_name(0);
let expected_columns = maplit::hashmap! {
ROWID_NAME => DataType::Int64,
row_id_col_name.as_str() => DataType::Int64,
"v1" => DataType::Int16,
"v2" => DataType::Int32,
"v3" => DataType::Int64,
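
The column layout that gen_create_table_plan now produces can be modeled in isolation: the hidden row-id column takes column id 0 under a generated name, user columns follow at ids 1..n, and any user column whose name starts with the reserved prefix is rejected. The ColDesc struct and plan_columns function below are simplified stand-ins for the real ColumnDesc handling, assuming only what the diff above shows.

// Simplified model of the column layout produced for CREATE TABLE.
const ROWID_PREFIX: &str = "_row_id";

#[derive(Debug)]
struct ColDesc {
    column_id: i32,
    name: String,
}

fn plan_columns(user_cols: &[&str]) -> Result<Vec<ColDesc>, String> {
    // Column 0 is the hidden row id with a generated, reserved name.
    let mut cols = vec![ColDesc {
        column_id: 0,
        name: format!("{}#{}", ROWID_PREFIX, 0),
    }];
    // User columns follow, and may not collide with the reserved prefix.
    for (i, name) in user_cols.iter().enumerate() {
        if name.starts_with(ROWID_PREFIX) {
            return Err(format!(
                "column names prefixed with {:?} are reserved",
                ROWID_PREFIX
            ));
        }
        cols.push(ColDesc {
            column_id: (i + 1) as i32,
            name: name.to_string(),
        });
    }
    Ok(cols)
}

fn main() {
    // CREATE TABLE t (v1 smallint, v2 int) yields _row_id#0, v1, v2.
    println!("{:?}", plan_columns(&["v1", "v2"]).unwrap());
    // A user column named "_row_id" is rejected up front.
    assert!(plan_columns(&["_row_id"]).is_err());
}
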
3 changes: 2 additions & 1 deletion rust/frontend/src/optimizer/mod.rs
@@ -25,6 +25,7 @@ use fixedbitset::FixedBitSet;
use itertools::Itertools as _;
use property::{Distribution, Order};
use risingwave_common::catalog::Schema;
use risingwave_common::error::Result;

use self::heuristic::{ApplyOrder, HeuristicOptimizer};
use self::plan_node::{LogicalProject, StreamMaterialize};
@@ -170,7 +171,7 @@ impl PlanRoot {
///
/// The `MaterializeExecutor` won't be generated at this stage, and will be attached in
/// `gen_create_mv_plan`.
pub fn gen_create_mv_plan(&mut self, mv_name: String) -> StreamMaterialize {
pub fn gen_create_mv_plan(&mut self, mv_name: String) -> Result<StreamMaterialize> {
let stream_plan = match self.plan.convention() {
Convention::Logical => {
let plan = self.gen_optimized_logical_plan();
14 changes: 11 additions & 3 deletions rust/frontend/src/optimizer/plan_node/logical_project.rs
@@ -41,7 +41,7 @@ pub struct LogicalProject {
impl LogicalProject {
pub fn new(input: PlanRef, exprs: Vec<ExprImpl>, expr_alias: Vec<Option<String>>) -> Self {
let ctx = input.ctx();
let schema = Self::derive_schema(&exprs, &expr_alias);
let schema = Self::derive_schema(&exprs, &expr_alias, input.schema());
let pk_indices = Self::derive_pk(input.schema(), input.pk_indices(), &exprs);
for expr in &exprs {
assert_input_ref!(expr, input.schema().fields().len());
@@ -126,13 +126,21 @@ impl LogicalProject {
LogicalProject::new(input, exprs, alias).into()
}

fn derive_schema(exprs: &[ExprImpl], expr_alias: &[Option<String>]) -> Schema {
fn derive_schema(
exprs: &[ExprImpl],
expr_alias: &[Option<String>],
input_schema: &Schema,
) -> Schema {
let o2i = Self::o2i_col_mapping_inner(input_schema.len(), exprs);
let fields = exprs
.iter()
.zip_eq(expr_alias.iter())
.enumerate()
.map(|(id, (expr, alias))| {
let name = alias.clone().unwrap_or(format!("expr#{}", id));
let name = alias.clone().unwrap_or(match o2i.try_map(id) {
Some(input_idx) => input_schema.fields()[input_idx].name.clone(),
None => format!("expr#{}", id),
});
Field {
name,
data_type: expr.return_type(),
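
The naming rule in the reworked derive_schema is: an explicit alias wins; otherwise, if the output column is a bare reference to an input column (which is what o2i_col_mapping_inner recovers), the input field's name is carried through; only genuinely computed expressions fall back to expr#<idx>. The sketch below models that rule with plain Option values in place of the planner's expression and mapping types; derive_names and bare_input are illustrative names.

// Simplified model of LogicalProject::derive_schema's naming rule.
// `bare_input[i]` is Some(input index) when output column i is a plain
// column reference, mirroring what the output-to-input mapping recovers.
fn derive_names(
    input_names: &[&str],
    bare_input: &[Option<usize>],
    aliases: &[Option<&str>],
) -> Vec<String> {
    bare_input
        .iter()
        .zip(aliases.iter())
        .enumerate()
        .map(|(i, (input_idx, alias))| match (alias, input_idx) {
            // 1. An explicit alias always takes priority.
            (Some(a), _) => a.to_string(),
            // 2. A bare column reference keeps the input column's name.
            (None, Some(idx)) => input_names[*idx].to_string(),
            // 3. Computed expressions fall back to a positional name.
            (None, None) => format!("expr#{}", i),
        })
        .collect()
}

fn main() {
    // SELECT v2, v1 + 1 AS s, v1 * 2 FROM t(v1, v2)
    let names = derive_names(
        &["v1", "v2"],
        &[Some(1), None, None],
        &[None, Some("s"), None],
    );
    assert_eq!(names, vec!["v2", "s", "expr#2"]);
}
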
59 changes: 48 additions & 11 deletions rust/frontend/src/optimizer/plan_node/stream_materialize.rs
@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::fmt;

use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::catalog::{ColumnDesc, OrderedColumnDesc, Schema, TableId};
use risingwave_common::catalog::{ColumnDesc, Field, OrderedColumnDesc, Schema, TableId};
use risingwave_common::error::ErrorCode::InternalError;
use risingwave_common::error::Result;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::expr::InputRefExpr;
use risingwave_pb::plan::ColumnOrder;
@@ -25,7 +28,7 @@ use risingwave_pb::stream_plan::stream_node::Node as ProstStreamNode;
use super::{PlanRef, PlanTreeNodeUnary, ToStreamProst};
use crate::catalog::column_catalog::ColumnCatalog;
use crate::catalog::table_catalog::TableCatalog;
use crate::catalog::ColumnId;
use crate::catalog::{gen_row_id_column_name, is_row_id_column_name, ColumnId};
use crate::optimizer::plan_node::{PlanBase, PlanNode};
use crate::optimizer::property::{Order, WithSchema};

@@ -39,22 +42,56 @@ pub struct StreamMaterialize {
}

impl StreamMaterialize {
fn derive_plan_base(input: &PlanRef) -> PlanBase {
fn derive_plan_base(input: &PlanRef) -> Result<PlanBase> {
let ctx = input.ctx();
let schema = input.schema();

let schema = Self::derive_schema(input.schema())?;
let pk_indices = input.pk_indices();

PlanBase::new_stream(
Ok(PlanBase::new_stream(
ctx,
schema.clone(),
schema,
pk_indices.to_vec(),
input.distribution().clone(),
input.append_only(),
)
))
}

fn derive_schema(schema: &Schema) -> Result<Schema> {
let mut col_names = HashSet::new();
for field in schema.fields() {
if is_row_id_column_name(&field.name) {
continue;
}
if !col_names.insert(field.name.clone()) {
return Err(InternalError(format!(
"column {} specified more than once",
field.name
))
.into());
}
}
let mut row_id_count = 0;
let fields = schema
.fields()
.iter()
.map(|field| match is_row_id_column_name(&field.name) {
true => {
let field = Field {
data_type: field.data_type.clone(),
name: gen_row_id_column_name(row_id_count),
};
row_id_count += 1;
field
}
false => field.clone(),
})
.collect();
Ok(Schema { fields })
}
#[must_use]
pub fn new(input: PlanRef, table: TableCatalog) -> Self {
let base = Self::derive_plan_base(&input);
let base = Self::derive_plan_base(&input).unwrap();
Self { base, input, table }
}

Expand All @@ -64,8 +101,8 @@ impl StreamMaterialize {
mv_name: String,
user_order_by: Order,
user_cols: FixedBitSet,
) -> Self {
let base = Self::derive_plan_base(&input);
) -> Result<Self> {
let base = Self::derive_plan_base(&input)?;
let schema = &base.schema;
let pk_indices = &base.pk_indices;
// Materialize executor won't change the append-only behavior of the stream, so it depends
@@ -115,7 +152,7 @@ impl StreamMaterialize {
pk_desc,
};

Self { base, input, table }
Ok(Self { base, input, table })
}

/// Get a reference to the stream materialize's table.
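
Likewise, the duplicate check and row-id renumbering in StreamMaterialize::derive_schema can be sketched over plain name lists: user-visible column names must be unique, while hidden row-id columns are exempt and are renamed _row_id#0, _row_id#1, ... in order of appearance. derive_mv_column_names below is a simplified stand-in that works on &str names instead of the frontend's Schema and Field types.

use std::collections::HashSet;

const ROWID_PREFIX: &str = "_row_id";

// Simplified model of StreamMaterialize::derive_schema: reject duplicate
// user-visible column names, then renumber hidden row-id columns.
fn derive_mv_column_names(names: &[&str]) -> Result<Vec<String>, String> {
    let mut seen = HashSet::new();
    for name in names {
        // Hidden row-id columns are exempt from the uniqueness check.
        if name.starts_with(ROWID_PREFIX) {
            continue;
        }
        if !seen.insert(*name) {
            return Err(format!("column {} specified more than once", name));
        }
    }
    // Row-id columns are renamed with a fresh running index.
    let mut row_id_count = 0;
    Ok(names
        .iter()
        .map(|name| {
            if name.starts_with(ROWID_PREFIX) {
                let renamed = format!("{}#{}", ROWID_PREFIX, row_id_count);
                row_id_count += 1;
                renamed
            } else {
                name.to_string()
            }
        })
        .collect())
}

fn main() {
    // Two row-id columns (e.g. from joining two tables) get distinct names.
    let ok = derive_mv_column_names(&["_row_id#0", "v1", "_row_id#0", "v2"]).unwrap();
    assert_eq!(ok, vec!["_row_id#0", "v1", "_row_id#1", "v2"]);

    // Duplicate user column names now fail instead of silently colliding.
    assert!(derive_mv_column_names(&["v1", "v1"]).is_err());
}
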