Commit
Merge pull request #2055 from mabel-dev/#2054
joocer authored Oct 9, 2024
2 parents 2e220ef + 966ac9c commit 79ad30a
Showing 6 changed files with 26 additions and 242 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
@@ -1,4 +1,4 @@
__build__ = 819
__build__ = 820

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
36 changes: 13 additions & 23 deletions opteryx/compiled/list_ops/cython_list_ops.pyx
@@ -8,7 +8,6 @@ import cython
import numpy
cimport numpy as cnp
from cython import Py_ssize_t
from cython.parallel import prange
from numpy cimport int64_t, ndarray
from cpython.unicode cimport PyUnicode_AsUTF8String

@@ -266,41 +265,32 @@ cpdef cnp.ndarray cython_get_element_op(cnp.ndarray[object, ndim=1] array, int k
return result


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cnp.ndarray array_encode_utf8(cnp.ndarray inp):
"""
utf-8 encode all elements of a 1d ndarray of "object" dtype.
A new ndarray of bytes objects is returned.
This converts about 5 million short strings (twitter user names) per second,
and 3 million tweets per second. Raw Python is many times slower.
Parameters:
inp: list or ndarray
The input array to encode.
Returns:
numpy.ndarray
A new ndarray with utf-8 encoded bytes objects.
Parallel UTF-8 encode all elements of a 1D ndarray of "object" dtype.
"""
cdef Py_ssize_t i, n = inp.shape[0]
cdef object[:] inp_view = inp # Create a memory view for faster access
cdef Py_ssize_t n = inp.shape[0]
cdef cnp.ndarray out = numpy.empty(n, dtype=object)
cdef object[:] inp_view = inp
cdef object[:] out_view = out

# Iterate and encode
for i in range(n):
inp_view[i] = PyUnicode_AsUTF8String(inp_view[i])
out_view[i] = PyUnicode_AsUTF8String(inp_view[i])

return inp
return out
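
As a minimal sketch (not part of the diff), the behaviour the revised array_encode_utf8 is intended to have can be written in plain Python: encode every element into a new object array rather than mutating the input, which is what the switch from returning inp to returning out achieves. The name array_encode_utf8_reference below is hypothetical.

# Illustrative sketch only: pure-Python equivalent of the revised behaviour.
import numpy

def array_encode_utf8_reference(inp):
    # Encode each element of a 1-D object array to UTF-8 bytes,
    # writing into a new array and leaving the input untouched.
    out = numpy.empty(inp.shape[0], dtype=object)
    for i, value in enumerate(inp):
        out[i] = value.encode("utf-8")
    return out

names = numpy.array(["alice", "bób", "チャーリー"], dtype=object)
encoded = array_encode_utf8_reference(names)
assert names[0] == "alice"      # the input array is not mutated
assert encoded[0] == b"alice"   # the output holds bytes objects

Returning a fresh array means callers that still hold a reference to the original string array no longer see it silently converted to bytes.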


cpdef cnp.ndarray list_contains_any(cnp.ndarray array, cnp.ndarray items):
cpdef cnp.ndarray[cnp.uint8_t, ndim=1] list_contains_any(cnp.ndarray array, cnp.ndarray items):
"""
Cython optimized version that works with object arrays.
"""
cdef set items_set = set(items[0])
cdef Py_ssize_t size = array.size
cdef cnp.ndarray res = numpy.zeros(size, dtype=numpy.bool_)
cdef cnp.ndarray[cnp.uint8_t, ndim=1] res = numpy.zeros(size, dtype=numpy.uint8)
cdef Py_ssize_t i
cdef object test_set, el
cdef cnp.ndarray test_set

for i in range(size):
test_set = array[i]
@@ -309,4 +299,4 @@ cpdef cnp.ndarray list_contains_any(cnp.ndarray array, cnp.ndarray items):
if el in items_set:
res[i] = True
break
return res
return res
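
For reference, a hedged pure-Python sketch of what list_contains_any computes: a uint8 mask with one flag per row, set when any element of that row appears in the item set. The None guard is an assumption here (the corresponding loop lines are collapsed in this diff), and list_contains_any_reference is a hypothetical name.

# Illustrative sketch only: row-wise "contains any" over an object array.
import numpy

def list_contains_any_reference(array, items):
    items_set = set(items[0])                       # items arrives as a single-row array of values
    res = numpy.zeros(array.size, dtype=numpy.uint8)
    for i, row in enumerate(array):
        if row is not None and any(el in items_set for el in row):
            res[i] = 1
    return res

rows = numpy.array([["a", "b"], ["c"], None], dtype=object)
print(list_contains_any_reference(rows, numpy.array([["b", "x"]], dtype=object)))   # [1 0 0]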
121 changes: 6 additions & 115 deletions opteryx/compiled/structures/hash_table.pyx
@@ -115,7 +115,7 @@ cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):
data_array = column_data # Already a NumPy array

columns_data.append(data_array)
hashes = numpy.empty(num_rows, dtype=np.int64)
hashes = numpy.empty(num_rows, dtype=numpy.int64)

# Determine data type and compute hashes accordingly
if numpy.issubdtype(data_array.dtype, numpy.integer):
@@ -126,7 +126,7 @@ cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):
compute_object_hashes(data_array, null_hash, hashes)
else:
# For other types (e.g., strings), treat as object
compute_object_hashes(data_array.astype(np.object_), null_hash, hashes)
compute_object_hashes(data_array.astype(numpy.object_), null_hash, hashes)

columns_hashes.append(hashes)

@@ -163,7 +163,7 @@ cdef void compute_int_hashes(cnp.ndarray[cnp.int64_t] data, int64_t null_hash, c
value = data[i]
# Assuming a specific value represents missing data
# Adjust this condition based on how missing integers are represented
if value == numpy.iinfo(np.int64).min:
if value == numpy.iinfo(numpy.int64).min:
hashes[i] = null_hash
else:
hashes[i] = value # Hash of int is the int itself in Python 3
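
A small sketch of the null-sentinel rule shown in this hunk: the int64 minimum is treated as the missing-value marker and maps to the shared null hash, while every other integer hashes to itself. NULL_HASH and hash_int64 below are hypothetical stand-ins rather than names from the codebase.

# Illustrative sketch only: int64 minimum acts as the NULL sentinel.
import numpy

NULL_HASH = -1                                   # hypothetical placeholder value
INT64_MIN = numpy.iinfo(numpy.int64).min

def hash_int64(value: int) -> int:
    # The hash of an int is the int itself; the sentinel maps to the null hash.
    return NULL_HASH if value == INT64_MIN else value

assert hash_int64(42) == 42
assert hash_int64(INT64_MIN) == NULL_HASH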
@@ -188,19 +188,18 @@ cpdef tuple list_distinct(cnp.ndarray values, cnp.int32_t[::1] indices, HashSet
Py_ssize_t n = values.shape[0]
object v
int64_t hash_value
int32_t[::1] new_indices = np.empty(n, dtype=np.int32)
int32_t[::1] new_indices = numpy.empty(n, dtype=numpy.int32)

# Determine the dtype of the `values` array
cnp.dtype dtype = values.dtype

cnp.ndarray[::1] values_mv = values
cnp.ndarray new_values = np.empty(n, dtype=dtype)
cnp.ndarray new_values = numpy.empty(n, dtype=dtype)

if seen_hashes is None:
seen_hashes = HashSet()

for i in range(n):
v = values_mv[i]
v = values[i]
hash_value = <int64_t>hash(v)
if seen_hashes.insert(hash_value):
new_values[j] = v
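
Conceptually, list_distinct keeps the first occurrence of each value by hashing it and probing a shared HashSet. A hedged pure-Python sketch, assuming HashSet.insert returns True only for previously unseen hashes (as the code above suggests):

# Illustrative sketch only: hash-based de-duplication of value/index pairs.
def list_distinct_reference(values, indices):
    seen = set()
    new_values, new_indices = [], []
    for v, idx in zip(values, indices):
        h = hash(v)
        if h not in seen:            # stands in for seen_hashes.insert(h) returning True
            seen.add(h)
            new_values.append(v)
            new_indices.append(idx)
    return new_values, new_indices

print(list_distinct_reference(["a", "b", "a"], [0, 1, 2]))   # (['a', 'b'], [0, 1])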
@@ -277,111 +276,3 @@ cpdef HashTable hash_join_map(relation, list join_columns):
ht.insert(hash_value, non_null_indices[i])

return ht


"""
Below here is an incomplete attempt at rewriting the hash table builder to be faster.
The key points to make it faster are:
- specialized hashes for different column types
- more C native structures, relying on less Python
This is competitive but doesn't outright beat the above version and currently doesn't pass all of the tests
"""


import cython
import numpy as np
import pyarrow
from libc.stdint cimport int64_t
from libc.stdlib cimport malloc, free

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef HashTable _hash_join_map(relation, list join_columns):
"""
Build a hash table for join operations using column-based hashing.
Each column is hashed separately, and the results are combined efficiently.
Parameters:
relation: The pyarrow.Table to preprocess.
join_columns: A list of column names to join on.
Returns:
A HashTable where keys are combined hashes of the join column entries and
values are lists of row indices corresponding to each hash key.
"""
cdef HashTable ht = HashTable()
cdef int64_t num_rows = relation.num_rows
cdef int64_t num_columns = len(join_columns)

# Create an array to store column hashes
cdef int64_t* cell_hashes = <int64_t*>malloc(num_rows * num_columns * sizeof(int64_t))
if cell_hashes is NULL:
raise Exception("Unable to allocate memory")

# Allocate memory for the combined nulls array
cdef cnp.ndarray[uint8_t, ndim=1] combined_nulls = numpy.full(num_rows, 1, dtype=numpy.uint8)

# Process each column to update the combined null bitmap
cdef int64_t i, j, combined_hash
cdef object column, bitmap_buffer
cdef uint8_t bit, byte

for column_name in join_columns:
column = relation.column(column_name)

if column.null_count > 0:
combined_column = column.combine_chunks()
bitmap_buffer = combined_column.buffers()[0] # Get the null bitmap buffer

if bitmap_buffer is not None:
bitmap_array = numpy.frombuffer(bitmap_buffer, dtype=np.uint8)

for i in range(num_rows):
byte = bitmap_array[i // 8]
bit = (byte >> (i % 8)) & 1
combined_nulls[i] &= bit

# Determine row indices that have no nulls in any considered column
cdef cnp.ndarray non_null_indices = numpy.nonzero(combined_nulls)[0]

# Process each column by type
for j, column_name in enumerate(join_columns):
column = relation.column(column_name)

# Handle different PyArrow types
if pyarrow.types.is_string(column.type): # String column
for i in non_null_indices:
cell_hashes[j * num_rows + i] = hash(column[i].as_buffer().to_pybytes()) # Hash string
elif pyarrow.types.is_integer(column.type) or pyarrow.types.is_floating(column.type):
# Access the data buffer directly as a NumPy array
np_column = numpy.frombuffer(column.combine_chunks().buffers()[1], dtype=np.int64)
for i in non_null_indices:
cell_hashes[j * num_rows + i] = np_column[i] # Directly store as int64
elif pyarrow.types.is_boolean(column.type):
bitmap_buffer = column.buffers()[1] # Boolean values are stored in bitmap
bitmap_ptr = <uint8_t*>bitmap_buffer.address # Access the bitmap buffer
for i in non_null_indices:
byte_idx = i // 8
bit_idx = i % 8
bit_value = (bitmap_ptr[byte_idx] >> bit_idx) & 1
cell_hashes[j * num_rows + i] = bit_value # Convert to int64 (True -> 1, False -> 0)
elif pyarrow.types.is_date(column.type) or pyarrow.types.is_timestamp(column.type):
np_column = numpy.frombuffer(column.combine_chunks().buffers()[1], dtype=np.int64)
for i in non_null_indices:
cell_hashes[j * num_rows + i] = np_column[i] # Store as int64 timestamp

# Combine hash values (n * 31 + y pattern)
if num_columns == 1:
for i in non_null_indices:
ht.insert(cell_hashes[i], i)
else:
for i in non_null_indices:
combined_hash = 0
for j in range(num_columns):
combined_hash = combined_hash * 31 + cell_hashes[j * num_rows + i]
ht.insert(combined_hash, i) # Insert combined hash into the hash table

free(cell_hashes)
return ht
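
The multi-column case combines the per-column hashes with the classic "n * 31 + y" accumulation seen in the loop above. As a plain-Python sketch (combine_row_hashes and the example values are hypothetical):

# Illustrative sketch only: combining one pre-computed hash per join column.
def combine_row_hashes(column_hashes):
    combined = 0
    for h in column_hashes:
        combined = combined * 31 + h   # same accumulation as the Cython loop
    return combined

assert combine_row_hashes([42]) == 42              # single-column joins use the hash directly
assert combine_row_hashes([42, 7]) == 42 * 31 + 7  # two join columns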
15 changes: 0 additions & 15 deletions opteryx/functions/other_functions.py
@@ -141,21 +141,6 @@ def null_if(col1, col2):
return [None if a == b else a for a, b in zip(col1, col2)]


def case_when(conditions, values):
n_rows = len(conditions[0])
n_conditions = len(conditions)
res = []

for idx in range(n_rows):
for cond_idx in range(n_conditions):
if conditions[cond_idx][idx]:
res.append(values[cond_idx][idx])
break
else:
res.append(None)
return res


def cosine_similarity(arr, val):
"""
ad hoc cosine similarity function, slow.
@@ -44,15 +44,6 @@
LITERALS_TO_THE_RIGHT = {"Plus": "Minus", "Minus": "Plus"}


def _add_condition(existing_condition, new_condition):
if not existing_condition:
return new_condition
_and = Node(node_type=NodeType.AND)
_and.left = new_condition
_and.right = existing_condition
return _and


def remove_adjacent_wildcards(predicate):
"""
Remove adjacent wildcards from LIKE/ILIKE/NotLike/NotILike conditions.
@@ -65,63 +56,6 @@ def remove_adjacent_wildcards(predicate):
return predicate
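
The body of remove_adjacent_wildcards is collapsed in this diff; as a hedged sketch, collapsing runs of adjacent '%' wildcards in the pattern literal could look like the helper below (collapse_adjacent_wildcards is a hypothetical name, not the actual implementation).

# Illustrative sketch only: collapse runs of '%' into a single wildcard.
import re

def collapse_adjacent_wildcards(pattern: str) -> str:
    return re.sub(r"%{2,}", "%", pattern)

assert collapse_adjacent_wildcards("a%%b%%%c") == "a%b%c"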


def rewrite_to_starts_with(predicate):
"""
Rewrite LIKE/ILIKE conditions with a single trailing wildcard to STARTS_WITH function.
This optimization converts patterns like 'abc%' to a STARTS_WITH function, which can be
more efficiently processed by the underlying engine compared to a generic LIKE pattern.
"""
ignore_case = predicate.value == "ILike"
predicate.right.value = predicate.right.value[:-1]
predicate.node_type = NodeType.FUNCTION
predicate.value = "STARTS_WITH"
predicate.parameters = [
predicate.left,
predicate.right,
Node(node_type=NodeType.LITERAL, type=OrsoTypes.BOOLEAN, value=ignore_case),
]
return predicate


def rewrite_to_ends_with(predicate):
"""
Rewrite LIKE/ILIKE conditions with a single leading wildcard to ENDS_WITH function.
This optimization converts patterns like '%abc' to an ENDS_WITH function, which can be
more efficiently processed by the underlying engine compared to a generic LIKE pattern.
"""
ignore_case = predicate.value == "ILike"
predicate.right.value = predicate.right.value[1:]
predicate.node_type = NodeType.FUNCTION
predicate.value = "ENDS_WITH"
predicate.parameters = [
predicate.left,
predicate.right,
Node(node_type=NodeType.LITERAL, type=OrsoTypes.BOOLEAN, value=ignore_case),
]
return predicate


def rewrite_to_search(predicate):
"""
Rewrite LIKE/ILIKE conditions with leading and trailing wildcards to SEARCH function.
This optimization converts patterns like '%abc%' to a SEARCH function, which can be
more efficiently processed by the underlying engine compared to a generic LIKE pattern.
"""
ignore_case = predicate.value == "ILike"
predicate.right.value = predicate.right.value[1:-1]
predicate.node_type = NodeType.FUNCTION
predicate.value = "SEARCH"
predicate.parameters = [
predicate.left,
predicate.right,
Node(node_type=NodeType.LITERAL, type=OrsoTypes.BOOLEAN, value=ignore_case),
]
return predicate
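
Taken together, the three removed rewrites classified a LIKE/ILIKE pattern by where its wildcards sit. A hedged sketch of that classification (classify_like_pattern is a hypothetical helper, not part of the Opteryx API):

# Illustrative sketch only: which function each removed rewrite targeted.
def classify_like_pattern(pattern: str) -> str:
    wildcards = pattern.count("%")
    if wildcards == 1 and pattern.endswith("%"):
        return "STARTS_WITH"      # e.g. 'abc%'
    if wildcards == 1 and pattern.startswith("%"):
        return "ENDS_WITH"        # e.g. '%abc'
    if wildcards == 2 and pattern.startswith("%") and pattern.endswith("%"):
        return "SEARCH"           # e.g. '%abc%'
    return "LIKE"                 # anything else stays a generic LIKE

assert classify_like_pattern("abc%") == "STARTS_WITH"
assert classify_like_pattern("%abc%") == "SEARCH"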


def rewrite_in_to_eq(predicate):
"""
Rewrite IN conditions with a single value to equality conditions.
@@ -178,9 +112,6 @@ def reorder_interval_calc(predicate):
# Define dispatcher conditions and actions
dispatcher: Dict[str, Callable] = {
"remove_adjacent_wildcards": remove_adjacent_wildcards,
"rewrite_to_starts_with": rewrite_to_starts_with,
"rewrite_to_ends_with": rewrite_to_ends_with,
"rewrite_to_search": rewrite_to_search,
"rewrite_in_to_eq": rewrite_in_to_eq,
"reorder_interval_calc": reorder_interval_calc,
}
@@ -206,25 +137,6 @@ def _rewrite_predicate(predicate, statistics: QueryStatistics):
statistics.optimization_predicate_rewriter_remove_redundant_like += 1
predicate.value = LIKE_REWRITES[predicate.value]

if predicate.value in {"Like", "ILike"}:
if predicate.left.source_connector and predicate.left.source_connector.isdisjoint(
{"Sql", "Cql"}
):
if predicate.right.node_type == NodeType.LITERAL:
if predicate.right.value[-1] == "%" and predicate.right.value.count("%") == 1:
statistics.optimization_predicate_rewriter_like_to_starts_with += 1
return dispatcher["rewrite_to_starts_with"](predicate)
if predicate.right.value[0] == "%" and predicate.right.value.count("%") == 1:
statistics.optimization_predicate_rewriter_like_to_ends_with += 1
return dispatcher["rewrite_to_ends_with"](predicate)
if (
predicate.right.value[0] == "%"
and predicate.right.value[-1] == "%"
and predicate.right.value.count("%") == 2
):
statistics.optimization_predicate_rewriter_like_to_search += 1
return dispatcher["rewrite_to_search"](predicate)

if predicate.value == "AnyOpEq":
if predicate.right.node_type == NodeType.LITERAL:
statistics.optimization_predicate_rewriter_any_to_inlist += 1
6 changes: 6 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
@@ -2013,6 +2013,12 @@
# 2051
("SELECT CASE WHEN surfacePressure = 0 THEN -1 WHEN surfacePressure IS NULL THEN 0 ELSE -2 END FROM $planets", 9, 1, None),
("SELECT CASE WHEN surfacePressure = 0 THEN -1 ELSE -2 END FROM $planets", 9, 1, None),
# 2054
("SELECT DISTINCT sides FROM (SELECT * FROM $planets AS plans LEFT JOIN (SELECT ARRAY_AGG(id) as sids, planetId FROM $satellites GROUP BY planetId) AS sats ON plans.id = planetId) AS plansats CROSS JOIN UNNEST (sids) as sides", 177, 1, None),
("SELECT DISTINCT sides FROM (SELECT * FROM $planets AS plans LEFT JOIN (SELECT ARRAY_AGG(name) as sids, planetId FROM $satellites GROUP BY planetId) AS sats ON plans.id = planetId) AS plansats CROSS JOIN UNNEST (sids) as sides", 177, 1, None),
("SELECT DISTINCT sides FROM (SELECT * FROM $planets AS plans LEFT JOIN (SELECT ARRAY_AGG(gm) as sids, planetId FROM $satellites GROUP BY planetId) AS sats ON plans.id = planetId) AS plansats CROSS JOIN UNNEST (sids) as sides", 102, 1, None),
("SELECT DISTINCT sides FROM (SELECT * FROM $planets AS plans LEFT JOIN (SELECT ARRAY_AGG(birth_date) as sids, group FROM $astronauts GROUP BY group) AS sats ON plans.id = group) AS plansats CROSS JOIN UNNEST (sids) as sides", 125, 1, None),
("SELECT DISTINCT sides FROM (SELECT * FROM $planets AS plans LEFT JOIN (SELECT ARRAY_AGG(birth_place) as sids, group FROM $astronauts GROUP BY group) AS sats ON plans.id = group) AS plansats CROSS JOIN UNNEST (sids) as sides", 110, 1, None),
]
# fmt:on

