diff --git a/Cargo.toml b/Cargo.toml index 87d251314..6f17444df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,20 @@ [package] -name = "sqloxide" -version = "0.1.36" -authors = ["Will Eaton "] +name = "compute" +version = "0.1.0" +authors = ["@joocer"] edition = "2018" [lib] -name = "sqloxide" +name = "compute" crate-type = ["cdylib"] [dependencies] +pyo3 = { version = "0.20", features = ["extension-module"] } +numpy = "0.20" +ndarray = "0.15.3" pythonize = "0.20" serde = "1.0.171" -[dependencies.pyo3] -version = "0.20.0" -features = ["extension-module"] - [dependencies.sqlparser] version = "0.44.0" features = ["serde", "visitor"] \ No newline at end of file diff --git a/opteryx/__version__.py b/opteryx/__version__.py index 5ef19cbda..9c25a1d91 100644 --- a/opteryx/__version__.py +++ b/opteryx/__version__.py @@ -1,4 +1,4 @@ -__build__ = 385 +__build__ = 391 # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/opteryx/compiled/bloom_filter/__init__.py b/opteryx/compiled/bloom_filter/__init__.py new file mode 100644 index 000000000..b4b03f3b1 --- /dev/null +++ b/opteryx/compiled/bloom_filter/__init__.py @@ -0,0 +1 @@ +from bloom_filter import create_bloom_filter diff --git a/opteryx/compiled/bloom_filter/bloom_filter.pyx b/opteryx/compiled/bloom_filter/bloom_filter.pyx new file mode 100644 index 000000000..1b548bc99 --- /dev/null +++ b/opteryx/compiled/bloom_filter/bloom_filter.pyx @@ -0,0 +1,52 @@ +# cython: language_level=3 + +""" +approximately +- 500 items, with two hashes, in 4092 bits, would have a 5% FP rate. +This implementation runs in about 1/3 the time of the one in Orso +""" + + +from libc.stdlib cimport malloc, free +from libc.string cimport memset + +cdef class BloomFilter: + cdef: + unsigned char* bit_array + long size + + def __cinit__(self, long size): + self.size = size + # Allocate memory for the bit array and initialize to 0 + self.bit_array = malloc(size // 8 + 1) + if not self.bit_array: + raise MemoryError("Failed to allocate memory for the bit array.") + memset(self.bit_array, 0, size // 8 + 1) + + def __dealloc__(self): + if self.bit_array: + free(self.bit_array) + + cpdef void add(self, long item): + """Add an item to the Bloom filter""" + h1 = item % self.size + # Apply the golden ratio to the item and use modulo to wrap within the size of the bit array + h2 = (item * 1.618033988749895) % self.size + # Set bits using bitwise OR + self.bit_array[h1 // 8] |= 1 << (h1 % 8) + self.bit_array[h2 // 8] |= 1 << (h2 % 8) + + cpdef int possibly_contains(self, long item): + """Check if the item might be in the set""" + h1 = item % self.size + h2 = (item * item + 1) % self.size + # Check bits using bitwise AND + return (self.bit_array[h1 // 8] & (1 << (h1 % 8))) and \ + (self.bit_array[h2 // 8] & (1 << (h2 % 8))) + +def create_bloom_filter(int size, items): + """Create and populate a Bloom filter""" + cdef BloomFilter bf = BloomFilter(size) + for item in items: + bf.add(item) + return bf diff --git a/opteryx/third_party/pyarrow_ops/ops.py b/opteryx/third_party/pyarrow_ops/ops.py index 6383f697d..51a48362d 100644 --- a/opteryx/third_party/pyarrow_ops/ops.py +++ b/opteryx/third_party/pyarrow_ops/ops.py @@ -120,11 +120,12 @@ def _inner_filter_operations(arr, operator, value): return compute.greater_equal(arr, value).to_numpy(False).astype(dtype=bool) if operator == "InList": # MODIFIED FOR OPTERYX - val = set(value[0]) - return numpy.array([a in val for a in arr], dtype=numpy.bool_) # [#325]? + values = set(value[0]) + return numpy.array([a in values for a in arr], dtype=numpy.bool_) # [#325]? if operator == "NotInList": # MODIFIED FOR OPTERYX - see comment above - return numpy.array([a not in value[0] for a in arr], dtype=numpy.bool_) # [#325]? + values = set(value[0]) + return numpy.array([a not in values for a in arr], dtype=numpy.bool_) # [#325]? if operator == "Like": # MODIFIED FOR OPTERYX # null input emits null output, which should be false/0 diff --git a/opteryx/third_party/sqloxide/__init__.py b/opteryx/third_party/sqloxide/__init__.py index 2845c8812..aa2c47b68 100644 --- a/opteryx/third_party/sqloxide/__init__.py +++ b/opteryx/third_party/sqloxide/__init__.py @@ -7,10 +7,8 @@ This module is not from sqloxide, it is written for Opteryx. """ -from .sqloxide import mutate_expressions -from .sqloxide import mutate_relations -from .sqloxide import parse_sql -from .sqloxide import restore_ast +from opteryx.compute import parse_sql +from opteryx.compute import restore_ast # Explicitly define the API of this module for external consumers -__all__ = ["parse_sql", "restore_ast", "mutate_expressions", "mutate_relations"] +__all__ = ["parse_sql", "restore_ast"] diff --git a/setup.py b/setup.py index b3078a34a..35534d1a3 100644 --- a/setup.py +++ b/setup.py @@ -25,9 +25,7 @@ def is_mac(): # pragma: no cover def rust_build(setup_kwargs: Dict[str, Any]) -> None: setup_kwargs.update( { - "rust_extensions": [ - RustExtension("opteryx.third_party.sqloxide.sqloxide", "Cargo.toml", debug=False) - ], + "rust_extensions": [RustExtension("opteryx.compute", "Cargo.toml", debug=False)], "zip_safe": False, } ) @@ -93,6 +91,11 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None: language="c++", extra_compile_args=COMPILE_FLAGS + ["-std=c++11"], ), + Extension( + name="bloom_filter", + sources=["opteryx/compiled/bloom_filter/bloom_filter.pyx"], + extra_compile_args=COMPILE_FLAGS, + ), Extension( name="varchar_array", sources=["opteryx/compiled/functions/varchar_array.pyx"], diff --git a/src/lib.rs b/src/lib.rs index f111f9215..8c9373f6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,108 +1,22 @@ -use pythonize::pythonize; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -use pyo3::wrap_pyfunction; -use pythonize::PythonizeError; - -use sqlparser::ast::Statement; -use sqlparser::dialect::*; -use sqlparser::parser::Parser; - -mod visitor; -use visitor::{extract_expressions, extract_relations, mutate_expressions, mutate_relations}; -fn string_to_dialect(dialect: &str) -> Box { - match dialect.to_lowercase().as_str() { - "ansi" => Box::new(AnsiDialect {}), - "bigquery" | "bq" => Box::new(BigQueryDialect {}), - "clickhouse" => Box::new(ClickHouseDialect {}), - "generic" => Box::new(GenericDialect {}), - "hive" => Box::new(HiveDialect {}), - "ms" | "mssql" => Box::new(MsSqlDialect {}), - "mysql" => Box::new(MySqlDialect {}), - "postgres" => Box::new(PostgreSqlDialect {}), - "redshift" => Box::new(RedshiftSqlDialect {}), - "snowflake" => Box::new(SnowflakeDialect {}), - "sqlite" => Box::new(SQLiteDialect {}), - _ => { - println!("The dialect you chose was not recognized, falling back to 'generic'"); - Box::new(GenericDialect {}) - } - } -} -/// Function to parse SQL statements from a string. Returns a list with -/// one item per query statement. -/// -/// Available `dialects`: -/// - generic -/// - ansi -/// - hive -/// - ms (mssql) -/// - mysql -/// - postgres -/// - snowflake -/// - sqlite -/// - clickhouse -/// - redshift -/// - bigquery (bq) -/// -#[pyfunction] -#[pyo3(text_signature = "(sql, dialect)")] -fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult { - let chosen_dialect = string_to_dialect(dialect); - let parse_result = Parser::parse_sql(&*chosen_dialect, sql); - - let output = match parse_result { - Ok(statements) => { - pythonize(py, &statements).map_err(|e| { - let msg = e.to_string(); - PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}")) - })? - } - Err(e) => { - let msg = e.to_string(); - return Err(PyValueError::new_err(format!( - "Query parsing failed.\n\t{msg}" - ))); - } - }; - - Ok(output) -} +use pyo3::prelude::*; -/// This utility function allows reconstituing a modified AST back into list of SQL queries. -#[pyfunction] -#[pyo3(text_signature = "(ast)")] -fn restore_ast(_py: Python, ast: &PyAny) -> PyResult> { - let parse_result: Result, PythonizeError> = pythonize::depythonize(ast); +use pyo3::wrap_pyfunction; - let output = match parse_result { - Ok(statements) => statements, - Err(e) => { - let msg = e.to_string(); - return Err(PyValueError::new_err(format!( - "Query serialization failed.\n\t{msg}" - ))); - } - }; +mod sqloxide; +use sqloxide::{restore_ast, parse_sql}; - Ok(output - .iter() - .map(std::string::ToString::to_string) - .collect::>()) -} +mod list_ops; +use list_ops::{anyop_eq_numeric, anyop_eq_string}; #[pymodule] -fn sqloxide(_py: Python, m: &PyModule) -> PyResult<()> { +fn compute(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(parse_sql, m)?)?; m.add_function(wrap_pyfunction!(restore_ast, m)?)?; - // TODO: maybe refactor into seperate module - m.add_function(wrap_pyfunction!(extract_relations, m)?)?; - m.add_function(wrap_pyfunction!(mutate_relations, m)?)?; - m.add_function(wrap_pyfunction!(extract_expressions, m)?)?; - m.add_function(wrap_pyfunction!(mutate_expressions, m)?)?; + + m.add_function(wrap_pyfunction!(anyop_eq_numeric, m)?)?; + m.add_function(wrap_pyfunction!(anyop_eq_string, m)?)?; Ok(()) } \ No newline at end of file diff --git a/src/list_ops.rs b/src/list_ops.rs new file mode 100644 index 000000000..2f5a8530c --- /dev/null +++ b/src/list_ops.rs @@ -0,0 +1,40 @@ +use numpy::{PyArray1, PyArray2, IntoPyArray}; +use pyo3::{Python, PyResult, prelude::*}; + + +#[pyfunction] +pub fn anyop_eq_numeric(py: Python<'_>, literal: i64, arr: &PyArray2) -> PyResult>> { + let array = unsafe { arr.as_array() }; + let result = array.map_axis(ndarray::Axis(1), |row| { + row.iter().any(|&item| item == literal) + }); + Ok(result.into_pyarray(py).to_owned()) +} + + +use pyo3::types::{PyAny, PyString}; + +#[pyfunction] +pub fn anyop_eq_string(_py: Python, value: &str, arr: &PyAny) -> PyResult> { + // Assume `arr` is a 2D array-like object (e.g., numpy array or list of lists) + let rows = arr.getattr("shape")?.extract::<(usize, )>()?.0; + let mut results = Vec::new(); + + for i in 0..rows { + let row = arr.get_item((i,))?; + let mut found = false; + + // Assuming `row` can be iterated over, reflecting a sequence of strings. + for item in row.iter()? { + let item_str = item?.downcast::()?.to_str()?; + if item_str == value { + found = true; + break; + } + } + + results.push(found); + } + + Ok(results) +} diff --git a/src/sqloxide.rs b/src/sqloxide.rs new file mode 100644 index 000000000..bf2a5d11c --- /dev/null +++ b/src/sqloxide.rs @@ -0,0 +1,93 @@ +use pythonize::pythonize; + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +use pythonize::PythonizeError; + +use sqlparser::ast::Statement; +use sqlparser::dialect::*; +use sqlparser::parser::Parser; + +fn string_to_dialect(dialect: &str) -> Box { + match dialect.to_lowercase().as_str() { + "ansi" => Box::new(AnsiDialect {}), + "bigquery" | "bq" => Box::new(BigQueryDialect {}), + "clickhouse" => Box::new(ClickHouseDialect {}), + "generic" => Box::new(GenericDialect {}), + "hive" => Box::new(HiveDialect {}), + "ms" | "mssql" => Box::new(MsSqlDialect {}), + "mysql" => Box::new(MySqlDialect {}), + "postgres" => Box::new(PostgreSqlDialect {}), + "redshift" => Box::new(RedshiftSqlDialect {}), + "snowflake" => Box::new(SnowflakeDialect {}), + "sqlite" => Box::new(SQLiteDialect {}), + _ => { + println!("The dialect you chose was not recognized, falling back to 'generic'"); + Box::new(GenericDialect {}) + } + } +} + +/// Function to parse SQL statements from a string. Returns a list with +/// one item per query statement. +/// +/// Available `dialects`: +/// - generic +/// - ansi +/// - hive +/// - ms (mssql) +/// - mysql +/// - postgres +/// - snowflake +/// - sqlite +/// - clickhouse +/// - redshift +/// - bigquery (bq) +/// +#[pyfunction] +#[pyo3(text_signature = "(sql, dialect)")] +pub fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult { + let chosen_dialect = string_to_dialect(dialect); + let parse_result = Parser::parse_sql(&*chosen_dialect, sql); + + let output = match parse_result { + Ok(statements) => { + pythonize(py, &statements).map_err(|e| { + let msg = e.to_string(); + PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}")) + })? + } + Err(e) => { + let msg = e.to_string(); + return Err(PyValueError::new_err(format!( + "Query parsing failed.\n\t{msg}" + ))); + } + }; + + Ok(output) +} + +/// This utility function allows reconstituing a modified AST back into list of SQL queries. +#[pyfunction] +#[pyo3(text_signature = "(ast)")] +pub fn restore_ast(_py: Python, ast: &PyAny) -> PyResult> { + let parse_result: Result, PythonizeError> = pythonize::depythonize(ast); + + let output = match parse_result { + Ok(statements) => statements, + Err(e) => { + let msg = e.to_string(); + return Err(PyValueError::new_err(format!( + "Query serialization failed.\n\t{msg}" + ))); + } + }; + + Ok(output + .iter() + .map(std::string::ToString::to_string) + .collect::>()) +} + diff --git a/src/visitor.rs b/src/visitor.rs deleted file mode 100644 index 7e6e0bcf1..000000000 --- a/src/visitor.rs +++ /dev/null @@ -1,147 +0,0 @@ -use core::ops::ControlFlow; - -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -use serde::Serialize; - -use sqlparser::ast::{ - Statement, {visit_expressions, visit_expressions_mut, visit_relations, visit_relations_mut}, -}; - -// Refactored function for handling depythonization -fn depythonize_query(parsed_query: &PyAny) -> Result, PyErr> { - match pythonize::depythonize(parsed_query) { - Ok(statements) => Ok(statements), - Err(e) => { - let msg = e.to_string(); - Err(PyValueError::new_err(format!( - "Query serialization failed.\n\t{msg}" - ))) - } - } -} - -fn pythonize_query_output(py: Python, output: Vec) -> PyResult> -where - T: Sized + Serialize, -{ - match pythonize::pythonize(py, &output) { - Ok(p) => Ok(p), - Err(e) => { - let msg = e.to_string(); - Err(PyValueError::new_err(format!( - "Python object serialization failed.\n\t{msg}" - ))) - } - } -} - -#[pyfunction] -#[pyo3(text_signature = "(parsed_query)")] -pub fn extract_relations(py: Python, parsed_query: &PyAny) -> PyResult { - let statements = depythonize_query(parsed_query)?; - - let mut relations = Vec::new(); - for statement in statements { - visit_relations(&statement, |relation| { - relations.push(relation.clone()); - ControlFlow::<()>::Continue(()) - }); - } - - pythonize_query_output(py, relations) -} - -#[pyfunction] -#[pyo3(text_signature = "(parsed_query, func)")] -pub fn mutate_relations(_py: Python, parsed_query: &PyAny, func: &PyAny) -> PyResult> { - let mut statements = depythonize_query(parsed_query)?; - - for statement in &mut statements { - visit_relations_mut(statement, |table| { - for section in &mut table.0 { - let val = match func.call1((section.value.clone(),)) { - Ok(val) => val, - Err(e) => { - let msg = e.to_string(); - return ControlFlow::Break(PyValueError::new_err(format!( - "Python object serialization failed.\n\t{msg}" - ))); - } - }; - - section.value = val.to_string(); - } - ControlFlow::Continue(()) - }); - } - - Ok(statements - .iter() - .map(std::string::ToString::to_string) - .collect::>()) -} - -#[pyfunction] -#[pyo3(text_signature = "(parsed_query, func)")] -pub fn mutate_expressions(py: Python, parsed_query: &PyAny, func: &PyAny) -> PyResult> { - let mut statements = depythonize_query(parsed_query)?; - - for statement in &mut statements { - visit_expressions_mut(statement, |expr| { - let converted_expr = match pythonize::pythonize(py, expr) { - Ok(val) => val, - Err(e) => { - let msg = e.to_string(); - return ControlFlow::Break(PyValueError::new_err(format!( - "Python object deserialization failed.\n\t{msg}" - ))); - } - }; - - let func_result = match func.call1((converted_expr,)) { - Ok(val) => val, - Err(e) => { - let msg = e.to_string(); - return ControlFlow::Break(PyValueError::new_err(format!( - "Calling python function failed.\n\t{msg}" - ))); - } - }; - - *expr = match pythonize::depythonize(func_result) { - Ok(val) => val, - Err(e) => { - let msg = e.to_string(); - return ControlFlow::Break(PyValueError::new_err(format!( - "Python object reserialization failed.\n\t{msg}" - ))); - } - }; - - ControlFlow::Continue(()) - }); - } - - Ok(statements - .iter() - .map(std::string::ToString::to_string) - .collect::>()) -} - -#[pyfunction] -#[pyo3(text_signature = "(parsed_query)")] -pub fn extract_expressions(py: Python, parsed_query: &PyAny) -> PyResult { - let statements = depythonize_query(parsed_query)?; - - let mut expressions = Vec::new(); - for statement in statements { - visit_expressions(&statement, |expr| { - expressions.push(expr.clone()); - ControlFlow::<()>::Continue(()) - }); - } - - pythonize_query_output(py, expressions) -} \ No newline at end of file diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 6b2161f75..c06f7d704 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -106,6 +106,11 @@ ("SELECT * FROM $planets WHERE orbitalPeriod BETWEEN 100 AND 1000", 3, 20, None), ("SELECT * FROM $planets WHERE LENGTH(name) = 5", 3, 20, None), ("SELECT * FROM $planets WHERE LENGTH(name) <> 5", 6, 20, None), + ("SELECT * FROM $planets WHERE LENGTH(name) == 5", 3, 20, None), + ("SELECT * FROM $planets WHERE LENGTH(name) != 5", 6, 20, None), + ("SELECT * FROM $planets WHERE LENGTH(name) = 5", 3, 20, None), + ("SELECT * FROM $planets WHERE NOT LENGTH(name) = 5", 6, 20, None), + ("SELECT * FROM $planets WHERE NOT LENGTH(name) == 5", 6, 20, None), ("SELECT * FROM $planets LIMIT 5", 5, 20, None), ("SELECT * FROM $planets WHERE numberOfMoons = 0", 2, 20, None), ("SELECT id FROM $planets WHERE density > 4000 ORDER BY id ASC", 3, 1, None),