Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1543 #1545

Merged
merged 8 commits into from
Mar 28, 2024
Merged

#1543 #1545

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
[package]
name = "sqloxide"
version = "0.1.36"
authors = ["Will Eaton <me@wseaton.com>"]
name = "compute"
version = "0.1.0"
authors = ["@joocer"]
edition = "2018"

[lib]
name = "sqloxide"
name = "compute"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
numpy = "0.20"
ndarray = "0.15.3"
pythonize = "0.20"
serde = "1.0.171"

[dependencies.pyo3]
version = "0.20.0"
features = ["extension-module"]

[dependencies.sqlparser]
version = "0.44.0"
features = ["serde", "visitor"]
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 385
__build__ = 391

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 1 addition & 0 deletions opteryx/compiled/bloom_filter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from bloom_filter import create_bloom_filter
52 changes: 52 additions & 0 deletions opteryx/compiled/bloom_filter/bloom_filter.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# cython: language_level=3

"""
approximately
- 500 items, with two hashes, in 4092 bits, would have a 5% FP rate.
This implementation runs in about 1/3 the time of the one in Orso
"""


from libc.stdlib cimport malloc, free
from libc.string cimport memset

cdef class BloomFilter:
cdef:
unsigned char* bit_array
long size

def __cinit__(self, long size):
self.size = size
# Allocate memory for the bit array and initialize to 0
self.bit_array = <unsigned char*>malloc(size // 8 + 1)
if not self.bit_array:
raise MemoryError("Failed to allocate memory for the bit array.")
memset(self.bit_array, 0, size // 8 + 1)

def __dealloc__(self):
if self.bit_array:
free(self.bit_array)

cpdef void add(self, long item):
"""Add an item to the Bloom filter"""
h1 = item % self.size
# Apply the golden ratio to the item and use modulo to wrap within the size of the bit array
h2 = <long>(item * 1.618033988749895) % self.size
# Set bits using bitwise OR
self.bit_array[h1 // 8] |= 1 << (h1 % 8)
self.bit_array[h2 // 8] |= 1 << (h2 % 8)

cpdef int possibly_contains(self, long item):
"""Check if the item might be in the set"""
h1 = item % self.size
h2 = (item * item + 1) % self.size
# Check bits using bitwise AND
return (self.bit_array[h1 // 8] & (1 << (h1 % 8))) and \
(self.bit_array[h2 // 8] & (1 << (h2 % 8)))

def create_bloom_filter(int size, items):
"""Create and populate a Bloom filter"""
cdef BloomFilter bf = BloomFilter(size)
for item in items:
bf.add(item)
return bf
7 changes: 4 additions & 3 deletions opteryx/third_party/pyarrow_ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ def _inner_filter_operations(arr, operator, value):
return compute.greater_equal(arr, value).to_numpy(False).astype(dtype=bool)
if operator == "InList":
# MODIFIED FOR OPTERYX
val = set(value[0])
return numpy.array([a in val for a in arr], dtype=numpy.bool_) # [#325]?
values = set(value[0])
return numpy.array([a in values for a in arr], dtype=numpy.bool_) # [#325]?
if operator == "NotInList":
# MODIFIED FOR OPTERYX - see comment above
return numpy.array([a not in value[0] for a in arr], dtype=numpy.bool_) # [#325]?
values = set(value[0])
return numpy.array([a not in values for a in arr], dtype=numpy.bool_) # [#325]?
if operator == "Like":
# MODIFIED FOR OPTERYX
# null input emits null output, which should be false/0
Expand Down
8 changes: 3 additions & 5 deletions opteryx/third_party/sqloxide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
This module is not from sqloxide, it is written for Opteryx.
"""

from .sqloxide import mutate_expressions
from .sqloxide import mutate_relations
from .sqloxide import parse_sql
from .sqloxide import restore_ast
from opteryx.compute import parse_sql
from opteryx.compute import restore_ast

# Explicitly define the API of this module for external consumers
__all__ = ["parse_sql", "restore_ast", "mutate_expressions", "mutate_relations"]
__all__ = ["parse_sql", "restore_ast"]
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def is_mac(): # pragma: no cover
def rust_build(setup_kwargs: Dict[str, Any]) -> None:
setup_kwargs.update(
{
"rust_extensions": [
RustExtension("opteryx.third_party.sqloxide.sqloxide", "Cargo.toml", debug=False)
],
"rust_extensions": [RustExtension("opteryx.compute", "Cargo.toml", debug=False)],
"zip_safe": False,
}
)
Expand Down Expand Up @@ -93,6 +91,11 @@ def rust_build(setup_kwargs: Dict[str, Any]) -> None:
language="c++",
extra_compile_args=COMPILE_FLAGS + ["-std=c++11"],
),
Extension(
name="bloom_filter",
sources=["opteryx/compiled/bloom_filter/bloom_filter.pyx"],
extra_compile_args=COMPILE_FLAGS,
),
Extension(
name="varchar_array",
sources=["opteryx/compiled/functions/varchar_array.pyx"],
Expand Down
106 changes: 10 additions & 96 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,108 +1,22 @@
use pythonize::pythonize;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use pyo3::wrap_pyfunction;
use pythonize::PythonizeError;

use sqlparser::ast::Statement;
use sqlparser::dialect::*;
use sqlparser::parser::Parser;

mod visitor;
use visitor::{extract_expressions, extract_relations, mutate_expressions, mutate_relations};

fn string_to_dialect(dialect: &str) -> Box<dyn Dialect> {
match dialect.to_lowercase().as_str() {
"ansi" => Box::new(AnsiDialect {}),
"bigquery" | "bq" => Box::new(BigQueryDialect {}),
"clickhouse" => Box::new(ClickHouseDialect {}),
"generic" => Box::new(GenericDialect {}),
"hive" => Box::new(HiveDialect {}),
"ms" | "mssql" => Box::new(MsSqlDialect {}),
"mysql" => Box::new(MySqlDialect {}),
"postgres" => Box::new(PostgreSqlDialect {}),
"redshift" => Box::new(RedshiftSqlDialect {}),
"snowflake" => Box::new(SnowflakeDialect {}),
"sqlite" => Box::new(SQLiteDialect {}),
_ => {
println!("The dialect you chose was not recognized, falling back to 'generic'");
Box::new(GenericDialect {})
}
}
}

/// Function to parse SQL statements from a string. Returns a list with
/// one item per query statement.
///
/// Available `dialects`:
/// - generic
/// - ansi
/// - hive
/// - ms (mssql)
/// - mysql
/// - postgres
/// - snowflake
/// - sqlite
/// - clickhouse
/// - redshift
/// - bigquery (bq)
///
#[pyfunction]
#[pyo3(text_signature = "(sql, dialect)")]
fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> {
let chosen_dialect = string_to_dialect(dialect);
let parse_result = Parser::parse_sql(&*chosen_dialect, sql);

let output = match parse_result {
Ok(statements) => {
pythonize(py, &statements).map_err(|e| {
let msg = e.to_string();
PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}"))
})?
}
Err(e) => {
let msg = e.to_string();
return Err(PyValueError::new_err(format!(
"Query parsing failed.\n\t{msg}"
)));
}
};

Ok(output)
}
use pyo3::prelude::*;

/// This utility function allows reconstituing a modified AST back into list of SQL queries.
#[pyfunction]
#[pyo3(text_signature = "(ast)")]
fn restore_ast(_py: Python, ast: &PyAny) -> PyResult<Vec<String>> {
let parse_result: Result<Vec<Statement>, PythonizeError> = pythonize::depythonize(ast);
use pyo3::wrap_pyfunction;

let output = match parse_result {
Ok(statements) => statements,
Err(e) => {
let msg = e.to_string();
return Err(PyValueError::new_err(format!(
"Query serialization failed.\n\t{msg}"
)));
}
};
mod sqloxide;
use sqloxide::{restore_ast, parse_sql};

Ok(output
.iter()
.map(std::string::ToString::to_string)
.collect::<Vec<String>>())
}
mod list_ops;
use list_ops::{anyop_eq_numeric, anyop_eq_string};

#[pymodule]
fn sqloxide(_py: Python, m: &PyModule) -> PyResult<()> {
fn compute(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_sql, m)?)?;
m.add_function(wrap_pyfunction!(restore_ast, m)?)?;
// TODO: maybe refactor into seperate module
m.add_function(wrap_pyfunction!(extract_relations, m)?)?;
m.add_function(wrap_pyfunction!(mutate_relations, m)?)?;
m.add_function(wrap_pyfunction!(extract_expressions, m)?)?;
m.add_function(wrap_pyfunction!(mutate_expressions, m)?)?;

m.add_function(wrap_pyfunction!(anyop_eq_numeric, m)?)?;
m.add_function(wrap_pyfunction!(anyop_eq_string, m)?)?;
Ok(())
}
40 changes: 40 additions & 0 deletions src/list_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use numpy::{PyArray1, PyArray2, IntoPyArray};
use pyo3::{Python, PyResult, prelude::*};


#[pyfunction]
pub fn anyop_eq_numeric(py: Python<'_>, literal: i64, arr: &PyArray2<i64>) -> PyResult<Py<PyArray1<bool>>> {
let array = unsafe { arr.as_array() };
let result = array.map_axis(ndarray::Axis(1), |row| {
row.iter().any(|&item| item == literal)
});
Ok(result.into_pyarray(py).to_owned())
}


use pyo3::types::{PyAny, PyString};

#[pyfunction]
pub fn anyop_eq_string(_py: Python, value: &str, arr: &PyAny) -> PyResult<Vec<bool>> {
// Assume `arr` is a 2D array-like object (e.g., numpy array or list of lists)
let rows = arr.getattr("shape")?.extract::<(usize, )>()?.0;
let mut results = Vec::new();

for i in 0..rows {
let row = arr.get_item((i,))?;
let mut found = false;

// Assuming `row` can be iterated over, reflecting a sequence of strings.
for item in row.iter()? {
let item_str = item?.downcast::<PyString>()?.to_str()?;
if item_str == value {
found = true;
break;
}
}

results.push(found);
}

Ok(results)
}
93 changes: 93 additions & 0 deletions src/sqloxide.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use pythonize::pythonize;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use pythonize::PythonizeError;

use sqlparser::ast::Statement;
use sqlparser::dialect::*;
use sqlparser::parser::Parser;

fn string_to_dialect(dialect: &str) -> Box<dyn Dialect> {
match dialect.to_lowercase().as_str() {
"ansi" => Box::new(AnsiDialect {}),
"bigquery" | "bq" => Box::new(BigQueryDialect {}),
"clickhouse" => Box::new(ClickHouseDialect {}),
"generic" => Box::new(GenericDialect {}),
"hive" => Box::new(HiveDialect {}),
"ms" | "mssql" => Box::new(MsSqlDialect {}),
"mysql" => Box::new(MySqlDialect {}),
"postgres" => Box::new(PostgreSqlDialect {}),
"redshift" => Box::new(RedshiftSqlDialect {}),
"snowflake" => Box::new(SnowflakeDialect {}),
"sqlite" => Box::new(SQLiteDialect {}),
_ => {
println!("The dialect you chose was not recognized, falling back to 'generic'");
Box::new(GenericDialect {})
}
}
}

/// Function to parse SQL statements from a string. Returns a list with
/// one item per query statement.
///
/// Available `dialects`:
/// - generic
/// - ansi
/// - hive
/// - ms (mssql)
/// - mysql
/// - postgres
/// - snowflake
/// - sqlite
/// - clickhouse
/// - redshift
/// - bigquery (bq)
///
#[pyfunction]
#[pyo3(text_signature = "(sql, dialect)")]
pub fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> {
let chosen_dialect = string_to_dialect(dialect);
let parse_result = Parser::parse_sql(&*chosen_dialect, sql);

let output = match parse_result {
Ok(statements) => {
pythonize(py, &statements).map_err(|e| {
let msg = e.to_string();
PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}"))
})?
}
Err(e) => {
let msg = e.to_string();
return Err(PyValueError::new_err(format!(
"Query parsing failed.\n\t{msg}"
)));
}
};

Ok(output)
}

/// This utility function allows reconstituing a modified AST back into list of SQL queries.
#[pyfunction]
#[pyo3(text_signature = "(ast)")]
pub fn restore_ast(_py: Python, ast: &PyAny) -> PyResult<Vec<String>> {
let parse_result: Result<Vec<Statement>, PythonizeError> = pythonize::depythonize(ast);

let output = match parse_result {
Ok(statements) => statements,
Err(e) => {
let msg = e.to_string();
return Err(PyValueError::new_err(format!(
"Query serialization failed.\n\t{msg}"
)));
}
};

Ok(output
.iter()
.map(std::string::ToString::to_string)
.collect::<Vec<String>>())
}

Loading
Loading