Skip to content

Commit

Permalink
feat: add postgres response for trasaction related statements (Grepti…
Browse files Browse the repository at this point in the history
…meTeam#4562)

* feat: add postgres fixtures WIP

* feat: implement more postgres fixtures

* feat: add compatibility for transaction/set transaction/show transaction

* fix: improve regex for set transaction
  • Loading branch information
sunng87 authored and CookiePieWw committed Sep 17, 2024
1 parent f15db85 commit cce44ec
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/common/function/src/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod version;
use std::sync::Arc;

use build::BuildFunction;
use database::DatabaseFunction;
use database::{CurrentSchemaFunction, DatabaseFunction};
use pg_catalog::PGCatalogFunction;
use procedure_state::ProcedureStateFunction;
use timezone::TimezoneFunction;
Expand All @@ -37,6 +37,7 @@ impl SystemFunction {
registry.register(Arc::new(BuildFunction));
registry.register(Arc::new(VersionFunction));
registry.register(Arc::new(DatabaseFunction));
registry.register(Arc::new(CurrentSchemaFunction));
registry.register(Arc::new(TimezoneFunction));
registry.register_async(Arc::new(ProcedureStateFunction));
PGCatalogFunction::register(registry);
Expand Down
34 changes: 32 additions & 2 deletions src/common/function/src/system/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,35 @@ use crate::function::{Function, FunctionContext};
#[derive(Clone, Debug, Default)]
pub struct DatabaseFunction;

const NAME: &str = "database";
#[derive(Clone, Debug, Default)]
pub struct CurrentSchemaFunction;

const DATABASE_FUNCTION_NAME: &str = "database";
const CURRENT_SCHEMA_FUNCTION_NAME: &str = "current_schema";

impl Function for DatabaseFunction {
fn name(&self) -> &str {
NAME
DATABASE_FUNCTION_NAME
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}

fn signature(&self) -> Signature {
Signature::uniform(0, vec![], Volatility::Immutable)
}

fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let db = func_ctx.query_ctx.current_schema();

Ok(Arc::new(StringVector::from_slice(&[&db])) as _)
}
}

impl Function for CurrentSchemaFunction {
fn name(&self) -> &str {
CURRENT_SCHEMA_FUNCTION_NAME
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Expand All @@ -54,6 +78,12 @@ impl fmt::Display for DatabaseFunction {
}
}

impl fmt::Display for CurrentSchemaFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "CURRENT_SCHEMA")
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down
3 changes: 3 additions & 0 deletions src/common/function/src/system/pg_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

mod pg_get_userbyid;
mod table_is_visible;
mod version;

use std::sync::Arc;

use pg_get_userbyid::PGGetUserByIdFunction;
use table_is_visible::PGTableIsVisibleFunction;
use version::PGVersionFunction;

use crate::function_registry::FunctionRegistry;

Expand All @@ -35,5 +37,6 @@ impl PGCatalogFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register(Arc::new(PGTableIsVisibleFunction));
registry.register(Arc::new(PGGetUserByIdFunction));
registry.register(Arc::new(PGVersionFunction));
}
}
54 changes: 54 additions & 0 deletions src/common/function/src/system/pg_catalog/version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::{env, fmt};

use common_query::error::Result;
use common_query::prelude::{Signature, Volatility};
use datatypes::data_type::ConcreteDataType;
use datatypes::vectors::{StringVector, VectorRef};

use crate::function::{Function, FunctionContext};

#[derive(Clone, Debug, Default)]
pub(crate) struct PGVersionFunction;

impl fmt::Display for PGVersionFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, crate::pg_catalog_func_fullname!("VERSION"))
}
}

impl Function for PGVersionFunction {
fn name(&self) -> &str {
crate::pg_catalog_func_fullname!("version")
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}

fn signature(&self) -> Signature {
Signature::exact(vec![], Volatility::Immutable)
}

fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let result = StringVector::from(vec![format!(
"PostgreSQL 16.3 GreptimeDB {}",
env!("CARGO_PKG_VERSION")
)]);
Ok(Arc::new(result))
}
}
7 changes: 4 additions & 3 deletions src/servers/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

mod auth_handler;
mod fixtures;
mod handler;
mod server;
mod types;
Expand Down Expand Up @@ -41,13 +42,13 @@ use self::handler::DefaultQueryParser;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;

pub(crate) struct GreptimeDBStartupParameters {
version: &'static str,
version: String,
}

impl GreptimeDBStartupParameters {
fn new() -> GreptimeDBStartupParameters {
GreptimeDBStartupParameters {
version: env!("CARGO_PKG_VERSION"),
version: format!("16.3-greptime-{}", env!("CARGO_PKG_VERSION")),
}
}
}
Expand All @@ -58,7 +59,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
C: ClientInfo,
{
Some(HashMap::from([
("server_version".to_owned(), self.version.to_owned()),
("server_version".to_owned(), self.version.clone()),
("server_encoding".to_owned(), "UTF8".to_owned()),
("client_encoding".to_owned(), "UTF8".to_owned()),
("DateStyle".to_owned(), "ISO YMD".to_owned()),
Expand Down
167 changes: 167 additions & 0 deletions src/servers/src/postgres/fixtures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;

use futures::stream;
use once_cell::sync::Lazy;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::Type;
use pgwire::error::PgWireResult;
use pgwire::messages::data::DataRow;
use regex::Regex;
use session::context::QueryContextRef;

fn build_string_data_rows(
schema: Arc<Vec<FieldInfo>>,
rows: Vec<Vec<String>>,
) -> Vec<PgWireResult<DataRow>> {
rows.iter()
.map(|row| {
let mut encoder = DataRowEncoder::new(schema.clone());
for value in row {
encoder.encode_field(&Some(value))?;
}
encoder.finish()
})
.collect()
}

static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
HashMap::from([
("default_transaction_isolation", "read committed"),
("transaction isolation level", "read committed"),
("standard_conforming_strings", "on"),
("client_encoding", "UTF8"),
])
});

static SHOW_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^SHOW (.*?);?$").unwrap());
static SET_TRANSACTION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^SET TRANSACTION (.*?);?$").unwrap());
static TRANSACTION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(BEGIN|ROLLBACK|COMMIT);?").unwrap());

/// Process unsupported SQL and return fixed result as a compatibility solution
pub(crate) fn process<'a>(
query: &str,
_query_ctx: QueryContextRef,
) -> Option<PgWireResult<Vec<Response<'a>>>> {
// Transaction directives:
if let Some(tx) = TRANSACTION_PATTERN.captures(query) {
let tx_tag = &tx[1];
Some(Ok(vec![Response::Execution(Tag::new(
&tx_tag.to_uppercase(),
))]))
} else if let Some(show_var) = SHOW_PATTERN.captures(query) {
let show_var = show_var[1].to_lowercase();
if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) {
let f1 = FieldInfo::new(
show_var.clone(),
None,
None,
Type::VARCHAR,
FieldFormat::Text,
);
let schema = Arc::new(vec![f1]);
let data = stream::iter(build_string_data_rows(
schema.clone(),
vec![vec![value.to_string()]],
));

Some(Ok(vec![Response::Query(QueryResponse::new(schema, data))]))
} else {
None
}
} else if SET_TRANSACTION_PATTERN.is_match(query) {
Some(Ok(vec![Response::Execution(Tag::new("SET"))]))
} else {
None
}
}

#[cfg(test)]
mod test {
use session::context::{QueryContext, QueryContextRef};

use super::*;

fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) {
if let Response::Execution(tag) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
assert_eq!(Tag::new(t), tag);
} else {
panic!("Invalid response");
}
}

fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> {
if let Response::Query(resp) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
resp
} else {
panic!("Invalid response");
}
}

#[test]
fn test_process() {
let query_context = QueryContext::arc();

assert_tag("BEGIN", "BEGIN", query_context.clone());
assert_tag("BEGIN;", "BEGIN", query_context.clone());
assert_tag("begin;", "BEGIN", query_context.clone());
assert_tag("ROLLBACK", "ROLLBACK", query_context.clone());
assert_tag("ROLLBACK;", "ROLLBACK", query_context.clone());
assert_tag("rollback;", "ROLLBACK", query_context.clone());
assert_tag("COMMIT", "COMMIT", query_context.clone());
assert_tag("COMMIT;", "COMMIT", query_context.clone());
assert_tag("commit;", "COMMIT", query_context.clone());
assert_tag(
"SET TRANSACTION ISOLATION LEVEL READ COMMITTED",
"SET",
query_context.clone(),
);
assert_tag(
"SET TRANSACTION ISOLATION LEVEL READ COMMITTED;",
"SET",
query_context.clone(),
);
assert_tag(
"SET transaction isolation level READ COMMITTED;",
"SET",
query_context.clone(),
);

let resp = get_data("SHOW transaction isolation level", query_context.clone());
assert_eq!(1, resp.row_schema().len());
let resp = get_data("show client_encoding;", query_context.clone());
assert_eq!(1, resp.row_schema().len());
let resp = get_data("show standard_conforming_strings;", query_context.clone());
assert_eq!(1, resp.row_schema().len());
let resp = get_data("show default_transaction_isolation", query_context.clone());
assert_eq!(1, resp.row_schema().len());

assert!(process("SELECT 1", query_context.clone()).is_none());
assert!(process("SHOW TABLES ", query_context.clone()).is_none());
assert!(process("SET TIME_ZONE=utc ", query_context.clone()).is_none());
}
}
24 changes: 15 additions & 9 deletions src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use sql::dialect::PostgreSqlDialect;
use sql::parser::{ParseOptions, ParserContext};

use super::types::*;
use super::PostgresServerHandler;
use super::{fixtures, PostgresServerHandler};
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::SqlPlan;
Expand All @@ -58,20 +58,26 @@ impl SimpleQueryHandler for PostgresServerHandler {
let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
.with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
.start_timer();
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;

let mut results = Vec::with_capacity(outputs.len());
if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
resps
} else {
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;

for output in outputs {
let resp = output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
results.push(resp);
}
let mut results = Vec::with_capacity(outputs.len());

Ok(results)
for output in outputs {
let resp =
output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
results.push(resp);
}

Ok(results)
}
}
}

fn output_to_query_response<'a>(
pub(crate) fn output_to_query_response<'a>(
query_ctx: QueryContextRef,
output: Result<Output>,
field_format: &Format,
Expand Down

0 comments on commit cce44ec

Please sign in to comment.