diff --git a/tests/tpch/conftest.py b/tests/tpch/conftest.py
index bb70c2d0c4..80d4f625dd 100644
--- a/tests/tpch/conftest.py
+++ b/tests/tpch/conftest.py
@@ -63,8 +63,8 @@ def dataset_path(local, scale):
     }
     local_paths = {
         1: "./tpch-data/scale-1/",
-        10: "./tpch-data/scale10/",
-        100: "./tpch-data/scale100/",
+        10: "./tpch-data/scale-10/",
+        100: "./tpch-data/scale-100/",
     }
 
     if local:
@@ -167,7 +167,7 @@ def cluster_spec(scale):
     if scale == 10:
         return {
             "worker_vm_types": ["m6i.large"],
-            "n_workers": 16,
+            "n_workers": 8,
             **everywhere,
         }
     elif scale == 100:
@@ -178,7 +178,7 @@ def cluster_spec(scale):
         }
     elif scale == 1000:
         return {
-            "worker_vm_types": ["m6i.large"],
+            "worker_vm_types": ["m6i.xlarge"],
             "n_workers": 32,
             **everywhere,
         }
@@ -203,8 +203,9 @@ def cluster(
     make_chart,
 ):
     if local:
-        with LocalCluster() as cluster:
-            yield cluster
+        with dask.config.set({"distributed.scheduler.worker-saturation": 4}):
+            with LocalCluster() as cluster:
+                yield cluster
     else:
         kwargs = dict(
             name=f"tpch-{module}-{scale}-{name}",
@@ -317,7 +318,7 @@ def fs(local):
 def machine_spec(scale):
     if scale == 10:
         return {
-            "vm_type": "m6i.8xlarge",
+            "vm_type": "m6i.4xlarge",
         }
     elif scale == 100:
         return {
@@ -325,7 +326,7 @@ def machine_spec(scale):
         }
     elif scale == 1000:
         return {
-            "vm_type": "m6i.16xlarge",
+            "vm_type": "m6i.32xlarge",
         }
     elif scale == 10000:
         return {
@@ -399,7 +400,8 @@ def make_chart(request, name, tmp_path_factory, local, scale):
 
     with lock:
         generate(
-            outfile=os.path.join("charts", f"{local}-{scale}-query-{name}.json"),
+            outfile=os.path.join("charts", f"{local}-{scale}-{name}.json"),
             name=name,
             scale=scale,
+            local=local,
         )
diff --git a/tests/tpch/generate-data.py b/tests/tpch/generate-data.py
index c5476cd71a..512f1069ab 100644
--- a/tests/tpch/generate-data.py
+++ b/tests/tpch/generate-data.py
@@ -1,4 +1,5 @@
 import functools
+import multiprocessing
 import pathlib
 import tempfile
 import warnings
@@ -49,14 +50,22 @@ def generate(
             # workload is best with 1vCPU and ~3-4GiB memory
             worker_vm_types=["m7a.medium", "m3.medium"],
             worker_options={"nthreads": 1},
+            spot_policy="spot_with_fallback",
             region=REGION,
         ) as cluster:
-            cluster.adapt(minimum=1, maximum=350)
+            cluster.adapt(minimum=1, maximum=500)
             with cluster.get_client() as client:
                 jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
                 client.gather(jobs)
     else:
-        _tpch_data_gen(step=None, **kwargs)
+        with dask.distributed.LocalCluster(
+            threads_per_worker=1,
+            memory_limit=dask.distributed.system.MEMORY_LIMIT,
+            n_workers=multiprocessing.cpu_count() // 2,
+        ) as cluster:
+            with cluster.get_client() as client:
+                jobs = client.map(_tpch_data_gen, range(0, scale), **kwargs)
+                client.gather(jobs)
 
 
 def retry(f):
@@ -116,8 +125,8 @@ def _tpch_data_gen(
         con.sql(
             f"""
             SET memory_limit='{psutil.virtual_memory().available // 2**30 }G';
-            SET preserve_insertion_order=false;
             SET threads TO 1;
+            SET preserve_insertion_order=false;
             SET enable_progress_bar=false;
             """
         )
@@ -166,8 +175,8 @@ def _tpch_data_gen(
                 (format parquet, per_thread_output true, filename_pattern "{table}_{{uuid}}", overwrite_or_ignore)
                 """
             )
-            print(f"Finished exporting table {table}!")
-        print("Finished exporting all data!")
+            print(f"Finished exporting table {table}")
+        print("Finished exporting all data")
 
 
 def rows_approx_mb(con, table_name, partition_size: str):
diff --git a/tests/tpch/generate_plot.py b/tests/tpch/generate_plot.py
index 198ddee0f1..c3063d65da 100644
--- a/tests/tpch/generate_plot.py
+++ b/tests/tpch/generate_plot.py
@@ -3,7 +3,7 @@
 import pandas as pd
 
 
-def generate(outfile="chart.json", name=None, scale=None):
+def generate(outfile="chart.json", name=None, scale=None, local=None):
     df = pd.read_sql_table(table_name="test_run", con="sqlite:///benchmark.db")
 
     df = df[
@@ -47,7 +47,9 @@ def recent(df):
             ),
             tooltip=["library", "duration"],
         )
-        .properties(title=f"TPC-H -- scale:{df.scale.iloc[0]} name:{df.name.iloc[0]}")
+        .properties(
+            title=f"TPC-H: {local} scale {df.scale.iloc[0]} -- {df.name.iloc[0]}"
+        )
         .configure_title(
             fontSize=20,
         )
diff --git a/tests/tpch/test_dask.py b/tests/tpch/test_dask.py
index d848c6ff2c..13fa526735 100644
--- a/tests/tpch/test_dask.py
+++ b/tests/tpch/test_dask.py
@@ -5,7 +5,7 @@
 def test_query_1(client, dataset_path, fs):
     VAR1 = datetime(1998, 9, 2)
 
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")
 
     lineitem_filtered = lineitem_ds[lineitem_ds.l_shipdate <= VAR1]
     lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
@@ -45,11 +45,11 @@ def test_query_2(client, dataset_path, fs):
     var2 = "BRASS"
     var3 = "EUROPE"
 
-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
-    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs)
-    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_filtered = dd.read_parquet(dataset_path + "nation")
+    supplier_filtered = dd.read_parquet(dataset_path + "supplier")
+    part_filtered = dd.read_parquet(dataset_path + "part")
+    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp")
 
     region_filtered = region_ds[(region_ds["r_name"] == var3)]
     r_n_merged = nation_filtered.merge(
@@ -112,9 +112,9 @@ def test_query_3(client, dataset_path, fs):
     var1 = datetime.strptime("1995-03-15", "%Y-%m-%d")
     var2 = "BUILDING"
 
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    cutomer_ds = dd.read_parquet(dataset_path + "customer")
 
     lsel = lineitem_ds.l_shipdate > var1
     osel = orders_ds.o_orderdate < var1
@@ -137,8 +137,8 @@ def test_query_4(client, dataset_path, fs):
     date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
     date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")
 
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
 
     lsel = line_item_ds.l_commitdate < line_item_ds.l_receiptdate
     osel = (orders_ds.o_orderdate < date1) & (orders_ds.o_orderdate >= date2)
@@ -160,12 +160,12 @@ def test_query_5(client, dataset_path, fs):
    date1 = datetime.strptime("1994-01-01", "%Y-%m-%d")
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
 
-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region")
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")
 
     rsel = region_ds.r_name == "ASIA"
     osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
@@ -190,7 +190,7 @@ def test_query_6(client, dataset_path, fs):
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var3 = 24
 
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
 
     sel = (
         (line_item_ds.l_shipdate >= date1)
@@ -208,11 +208,11 @@ def test_query_7(client, dataset_path, fs):
     var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")
 
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    nation_ds = dd.read_parquet(dataset_path + "nation")
+    customer_ds = dd.read_parquet(dataset_path + "customer")
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem")
+    orders_ds = dd.read_parquet(dataset_path + "orders")
+    supplier_ds = dd.read_parquet(dataset_path + "supplier")
 
     lineitem_filtered = line_item_ds[
         (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
diff --git a/tests/tpch/test_polars.py b/tests/tpch/test_polars.py
index 6925513381..99ad4a9734 100644
--- a/tests/tpch/test_polars.py
+++ b/tests/tpch/test_polars.py
@@ -7,20 +7,20 @@
 
 def read_data(filename):
-    pyarrow_dataset = dataset(filename, format="parquet")
-    return pl.scan_pyarrow_dataset(pyarrow_dataset)
-
     if filename.startswith("s3://"):
+        pyarrow_dataset = dataset(filename, format="parquet")
+        return pl.scan_pyarrow_dataset(pyarrow_dataset)
         import boto3
 
         session = boto3.session.Session()
         credentials = session.get_credentials()
 
         return pl.scan_parquet(
-            filename,
+            filename + "/*",
             storage_options={
                 "aws_access_key_id": credentials.access_key,
                 "aws_secret_access_key": credentials.secret_key,
                 "region": "us-east-2",
+                "session_token": credentials.token,
            },
        )
    else:
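
For readability, here is roughly what read_data() in tests/tpch/test_polars.py looks like once the hunk above is applied. This is a sketch reconstructed from the diff, not copied from the repository: the import lines and the body of the local-path else branch are assumptions and are marked as such.

    import polars as pl                     # assumed import, not shown in the hunk
    from pyarrow.dataset import dataset     # assumed import, not shown in the hunk


    def read_data(filename):
        if filename.startswith("s3://"):
            # S3 paths are scanned through a PyArrow dataset; this early return
            # leaves the boto3/scan_parquet branch below unreachable.
            pyarrow_dataset = dataset(filename, format="parquet")
            return pl.scan_pyarrow_dataset(pyarrow_dataset)

            import boto3

            session = boto3.session.Session()
            credentials = session.get_credentials()

            return pl.scan_parquet(
                filename + "/*",
                storage_options={
                    "aws_access_key_id": credentials.access_key,
                    "aws_secret_access_key": credentials.secret_key,
                    "region": "us-east-2",
                    "session_token": credentials.token,
                },
            )
        else:
            ...  # local-path branch unchanged by this patch; not shown in the hunk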