TPCH - Working Branch #1166

Open · wants to merge 5 commits into main
Changes from 3 commits
20 changes: 11 additions & 9 deletions tests/tpch/conftest.py
@@ -63,8 +63,8 @@ def dataset_path(local, scale):
     }
     local_paths = {
         1: "./tpch-data/scale-1/",
-        10: "./tpch-data/scale10/",
-        100: "./tpch-data/scale100/",
+        10: "./tpch-data/scale-10/",
+        100: "./tpch-data/scale-100/",
     }

     if local:
@@ -167,7 +167,7 @@ def cluster_spec(scale):
     if scale == 10:
         return {
             "worker_vm_types": ["m6i.large"],
-            "n_workers": 16,
+            "n_workers": 8,
             **everywhere,
         }
     elif scale == 100:
@@ -178,7 +178,7 @@
         }
     elif scale == 1000:
         return {
-            "worker_vm_types": ["m6i.large"],
+            "worker_vm_types": ["m6i.xlarge"],
             "n_workers": 32,
             **everywhere,
         }
@@ -203,8 +203,9 @@ def cluster(
     make_chart,
 ):
     if local:
-        with LocalCluster() as cluster:
-            yield cluster
+        with dask.config.set({"distributed.scheduler.worker-saturation": 4}):
+            with LocalCluster() as cluster:
+                yield cluster
     else:
         kwargs = dict(
             name=f"tpch-{module}-{scale}-{name}",
@@ -317,15 +318,15 @@ def fs(local):
 def machine_spec(scale):
     if scale == 10:
         return {
-            "vm_type": "m6i.8xlarge",
+            "vm_type": "m6i.4xlarge",
         }
     elif scale == 100:
         return {
             "vm_type": "m6i.8xlarge",
         }
     elif scale == 1000:
         return {
-            "vm_type": "m6i.16xlarge",
+            "vm_type": "m6i.32xlarge",
         }
     elif scale == 10000:
         return {
@@ -399,7 +400,8 @@ def make_chart(request, name, tmp_path_factory, local, scale):

     with lock:
         generate(
-            outfile=os.path.join("charts", f"{local}-{scale}-query-{name}.json"),
+            outfile=os.path.join("charts", f"{local}-{scale}-{name}.json"),
+            name=name,
             scale=scale,
             local=local,
         )
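A note on the cluster fixture change above: `distributed.scheduler.worker-saturation` caps how many root tasks the scheduler assigns to each worker relative to its thread count (the default is 1.1), and raising it to 4 lets each local worker queue more work at once. A minimal standalone sketch of the pattern; everything beyond the config key and the factor of 4 is illustrative:

```python
import dask
from dask.distributed import Client, LocalCluster

# The config must be set before the cluster is created so the scheduler
# picks it up; here each worker may be assigned up to 4x its thread
# count in root tasks instead of the default 1.1x.
with dask.config.set({"distributed.scheduler.worker-saturation": 4}):
    with LocalCluster() as cluster, Client(cluster) as client:
        print(client.dashboard_link)  # illustrative: confirm the cluster is up
```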
19 changes: 13 additions & 6 deletions tests/tpch/generate-data.py
@@ -1,4 +1,5 @@
 import functools
+import multiprocessing
 import pathlib
 import tempfile
 import warnings
@@ -19,7 +20,7 @@ def generate(
     scale: int = 10,
     partition_size: str = "128 MiB",
     path: str = "./tpch-data",
-    relaxed_schema: bool = False,
+    relaxed_schema: bool = True,
 ):
     if str(path).startswith("s3"):
         path += "/" if not path.endswith("/") else ""
@@ -56,7 +57,14 @@
         jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
         client.gather(jobs)
     else:
-        _tpch_data_gen(step=None, **kwargs)
+        with dask.distributed.LocalCluster(
+            threads_per_worker=1,
+            memory_limit=dask.distributed.system.MEMORY_LIMIT,
+            n_workers=multiprocessing.cpu_count() // 2,
+        ) as cluster:
+            with cluster.get_client() as client:
+                jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
+                client.gather(jobs)


 def retry(f):
@@ -117,7 +125,6 @@ def _tpch_data_gen(
         f"""
         SET memory_limit='{psutil.virtual_memory().available // 2**30 }G';
         SET preserve_insertion_order=false;
-        SET threads TO 1;
         SET enable_progress_bar=false;
         """
     )
@@ -166,8 +173,8 @@
             (format parquet, per_thread_output true, filename_pattern "{table}_{{uuid}}", overwrite_or_ignore)
             """
         )
-        print(f"Finished exporting table {table}!")
-    print("Finished exporting all data!")
+        print(f"Finished exporting table {table}")
+    print("Finished exporting all data")


 def rows_approx_mb(con, table_name, partition_size: str):
@@ -246,7 +253,7 @@ def get_bucket_region(path: str):
 )
 @click.option(
     "--relaxed-schema",
-    default=False,
+    default=True,
     flag_value=True,
     help="Set flag to convert official TPC-H types decimal -> float and date -> timestamp_s",
 )
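For reference, the local branch of `generate` now fans the per-step data generation out over a process-based LocalCluster instead of a single in-process call. A minimal sketch of that pattern, with a stand-in task in place of `_tpch_data_gen` (the squaring function is purely illustrative):

```python
import multiprocessing

from dask.distributed import LocalCluster

# One single-threaded worker process per two CPU cores, mirroring the
# diff above; each step runs in its own process with its own memory
# budget rather than sharing one Python process.
with LocalCluster(
    threads_per_worker=1,
    n_workers=multiprocessing.cpu_count() // 2,
) as cluster:
    with cluster.get_client() as client:
        futures = client.map(lambda step: step * step, range(10))
        print(client.gather(futures))
```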
6 changes: 4 additions & 2 deletions tests/tpch/generate_plot.py
@@ -3,7 +3,7 @@
 import pandas as pd


-def generate(outfile="chart.json", name=None, scale=None):
+def generate(outfile="chart.json", name=None, scale=None, local=None):
     df = pd.read_sql_table(table_name="test_run", con="sqlite:///benchmark.db")

     df = df[
@@ -47,7 +47,9 @@ def recent(df):
             ),
             tooltip=["library", "duration"],
         )
-        .properties(title=f"TPC-H -- scale:{df.scale.iloc[0]} name:{df.name.iloc[0]}")
+        .properties(
+            title=f"TPC-H: {local} scale {df.scale.iloc[0]} -- {df.name.iloc[0]}"
+        )
         .configure_title(
             fontSize=20,
         )
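A hypothetical invocation of the updated helper, assuming a charts/ directory exists and benchmark.db holds at least one test_run row; the filename mirrors the pattern conftest.py now uses, and all values below are illustrative:

```python
from generate_plot import generate

# local now feeds the chart title (and, via the caller, the filename).
generate(
    outfile="charts/True-10-test_query_1.json",
    name="test_query_1",
    scale=10,
    local=True,
)
```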
46 changes: 23 additions & 23 deletions tests/tpch/test_dask.py
@@ -5,7 +5,7 @@

 def test_query_1(client, dataset_path, fs):
     VAR1 = datetime(1998, 9, 2)
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")

     lineitem_filtered = lineitem_ds[lineitem_ds.l_shipdate <= VAR1]
     lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
@@ -45,11 +45,11 @@ def test_query_2(client, dataset_path, fs):
     var2 = "BRASS"
     var3 = "EUROPE"

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
-    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs)
-    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_filtered = dd.read_parquet(dataset_path + "nation")
+    supplier_filtered = dd.read_parquet(dataset_path + "supplier")
+    part_filtered = dd.read_parquet(dataset_path + "part")
+    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp")

     region_filtered = region_ds[(region_ds["r_name"] == var3)]
     r_n_merged = nation_filtered.merge(
@@ -112,9 +112,9 @@ def test_query_3(client, dataset_path, fs):
     var1 = datetime.strptime("1995-03-15", "%Y-%m-%d")
     var2 = "BUILDING"

-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    cutomer_ds = dd.read_parquet(dataset_path + "customer")

     lsel = lineitem_ds.l_shipdate > var1
     osel = orders_ds.o_orderdate < var1
@@ -137,8 +137,8 @@ def test_query_4(client, dataset_path, fs):
     date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
     date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")

     lsel = line_item_ds.l_commitdate < line_item_ds.l_receiptdate
     osel = (orders_ds.o_orderdate < date1) & (orders_ds.o_orderdate >= date2)
@@ -160,12 +160,12 @@ def test_query_5(client, dataset_path, fs):
     date1 = datetime.strptime("1994-01-01", "%Y-%m-%d")
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")

-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")

     rsel = region_ds.r_name == "ASIA"
     osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
@@ -190,7 +190,7 @@ def test_query_6(client, dataset_path, fs):
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var3 = 24

-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")

     sel = (
         (line_item_ds.l_shipdate >= date1)
@@ -208,11 +208,11 @@ def test_query_7(client, dataset_path, fs):
     var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")

     lineitem_filtered = line_item_ds[
         (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
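Dropping the explicit `filesystem=fs` argument works because `dd.read_parquet` infers the filesystem from the path's protocol via fsspec. A hedged sketch; the bucket name and options are illustrative, not from the PR:

```python
import dask.dataframe as dd

# With no filesystem argument, an s3:// prefix selects s3fs automatically;
# credentials and other settings can still be passed via storage_options.
lineitem = dd.read_parquet(
    "s3://example-bucket/tpch/scale-10/lineitem",  # hypothetical bucket
    storage_options={"anon": False},
)
```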
5 changes: 2 additions & 3 deletions tests/tpch/test_polars.py
@@ -7,10 +7,9 @@


 def read_data(filename):
-    pyarrow_dataset = dataset(filename, format="parquet")
-    return pl.scan_pyarrow_dataset(pyarrow_dataset)
-
     if filename.startswith("s3://"):
+        pyarrow_dataset = dataset(filename, format="parquet")
+        return pl.scan_pyarrow_dataset(pyarrow_dataset)
         import boto3

         session = boto3.session.Session()
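The diff for the rest of `read_data` is cut off above, so the non-S3 branch is not visible. A plausible shape of the full helper under that assumption; the local branch and the glob pattern are guesses for illustration, not quotes from the PR:

```python
import polars as pl
from pyarrow.dataset import dataset


def read_data(filename: str) -> pl.LazyFrame:
    # S3 paths go through a pyarrow dataset, as in the diff above.
    if filename.startswith("s3://"):
        return pl.scan_pyarrow_dataset(dataset(filename, format="parquet"))
    # Assumed local branch: scan a directory of parquet files lazily.
    return pl.scan_parquet(filename + "/*.parquet")
```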