Commit 6b9be97

resolving merge conflict
Signed-off-by: Sayed Bilal Bari <sbari@nvidia.com>
bilalbari committed Jul 11, 2024
2 parents cd7700e + c41b702 commit 6b9be97
Showing 1 changed file with 73 additions and 29 deletions.

nds-h/nds_h_power.py
@@ -37,6 +37,8 @@
from pyspark.sql import SparkSession
import os
import sys
+import re

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

# Construct the path to the utils directory
@@ -47,9 +49,8 @@
from python_benchmark_reporter.PysparkBenchReport import PysparkBenchReport
from pyspark.sql import DataFrame

-from check import check_version
+from check import check_version, check_json_summary_folder, check_query_subset_exists
from nds_h_schema import get_schemas
-import re

check_version()

@@ -71,7 +72,7 @@ def gen_sql_from_stream(query_stream_file_path):
# Find all matches in the content
matches = pattern.findall(stream)

-# Populate the dictionary with template file numbers as keys and queries as values
+# Populate the dictionary with template file numbers as keys and queries as values
for match in matches:
template_number = match[0]
if int(template_number) == 15:
@@ -85,14 +86,14 @@ def gen_sql_from_stream(query_stream_file_path):

return extended_queries


def setup_tables(spark_session, input_prefix, input_format, execution_time_list):
"""set up data tables in Spark before running the Power Run queries.
Args:
spark_session (SparkSession): a SparkSession instance to run queries.
input_prefix (str): path of input data.
input_format (str): type of input data source, e.g. parquet, orc, csv, json.
-use_decimal (bool): use decimal type for certain columns when loading data of text type.
execution_time_list ([(str, str, int)]): a list to record query and its execution time.
Returns:
@@ -103,9 +104,11 @@ def setup_tables(spark_session, input_prefix, input_format, execution_time_list)
for table_name in get_schemas().keys():
start = int(time.time() * 1000)
table_path = input_prefix + '/' + table_name
-reader = spark_session.read.format(input_format)
+reader = spark_session.read.format(input_format)
if input_format in ['csv', 'json']:
reader = reader.schema(get_schemas()[table_name])
print("Loading table ", table_path)
print("table name ", table_name)
reader.load(table_path).createOrReplaceTempView(table_name)
end = int(time.time() * 1000)
print("====== Creating TempView for table {} ======".format(table_name))
@@ -114,6 +117,7 @@ def setup_tables(spark_session, input_prefix, input_format, execution_time_list)
(spark_app_id, "CreateTempView {}".format(table_name), end - start))
return execution_time_list


def ensure_valid_column_names(df: DataFrame):
def is_column_start(char):
return char.isalpha() or char == '_'
@@ -143,16 +147,17 @@ def deduplicate(column_names):
# In some queries like q35, it's possible to get columns with the same name. Append a number
# suffix to resolve this problem.
dedup_col_names = []
-for i,v in enumerate(column_names):
+for i, v in enumerate(column_names):
count = column_names.count(v)
index = column_names[:i].count(v)
-dedup_col_names.append(v+str(index) if count > 1 else v)
+dedup_col_names.append(v + str(index) if count > 1 else v)
return dedup_col_names

valid_col_names = [c if is_valid(c) else make_valid(c) for c in df.columns]
dedup_col_names = deduplicate(valid_col_names)
return df.toDF(*dedup_col_names)


def run_one_query(spark_session,
query,
query_name,
@@ -163,18 +168,29 @@ def run_one_query(spark_session,
df.collect()
else:
ensure_valid_column_names(df).write.format(output_format).mode('overwrite').save(
-output_path + '/' + query_name)
+output_path + '/' + query_name)


+def get_query_subset(query_dict, subset):
+    """Get a subset of queries from query_dict.
+    The subset is specified by a list of query names.
+    """
+    check_query_subset_exists(query_dict, subset)
+    return dict((k, query_dict[k]) for k in subset)
+

def run_query_stream(input_prefix,
property_file,
query_dict,
time_log_output_path,
input_format="parquet",
sub_queries,
input_format,
output_path=None,
keep_sc=False,
output_format="parquet"):
output_format="parquet",
json_summary_folder=None):
"""run SQL in Spark and record execution time log. The execution time log is saved as a CSV file
-for easy accesibility. TempView Creation time is also recorded.
+for easy accessibility. TempView Creation time is also recorded.
Args:
:param input_prefix : path of input data or warehouse if input_format is "iceberg" or hive_external=True.
@@ -187,6 +203,8 @@ def run_query_stream(input_prefix,
action will be applied to each query. Defaults to None.
:param output_format : query output format, choices are csv, orc, parquet. Defaults to "parquet".
:param keep_sc : Databricks specific to keep the spark context alive. Defaults to False.
+:param json_summary_folder : path to save JSON summary files for each query.
"""
queries_reports = []
execution_time_list = []
@@ -198,36 +216,48 @@ def run_query_stream(input_prefix,
session_builder = SparkSession.builder
if property_file:
spark_properties = load_properties(property_file)
-for k,v in spark_properties.items():
-session_builder = session_builder.config(k,v)
+for k, v in spark_properties.items():
+session_builder = session_builder.config(k, v)
spark_session = session_builder.appName(
app_name).getOrCreate()
spark_app_id = spark_session.sparkContext.applicationId
if input_format != 'iceberg' and input_format != 'delta':
execution_time_list = setup_tables(spark_session, input_prefix, input_format,
execution_time_list)

+check_json_summary_folder(json_summary_folder)
+if sub_queries:
+    query_dict = get_query_subset(query_dict, sub_queries)

power_start = int(time.time())
for query_name, q_content in query_dict.items():
# show query name in Spark web UI
spark_session.sparkContext.setJobGroup(query_name, query_name)
print("====== Run {} ======".format(query_name))
q_report = PysparkBenchReport(spark_session, query_name)
-summary = q_report.report_on(run_one_query,spark_session,
-q_content,
-query_name,
-output_path,
-output_format)
+summary = q_report.report_on(run_one_query, spark_session,
+q_content,
+query_name,
+output_path,
+output_format)
print(f"Time taken: {summary['queryTimes']} millis for {query_name}")
query_times = summary['queryTimes']
execution_time_list.append((spark_app_id, query_name, query_times[0]))
queries_reports.append(q_report)
+if json_summary_folder:
+    if property_file:
+        summary_prefix = os.path.join(
+            json_summary_folder, os.path.basename(property_file)
+        )
+    else:
+        summary_prefix = os.path.join(json_summary_folder, '')
+    q_report.write_summary(prefix=summary_prefix)
power_end = int(time.time())
power_elapse = int((power_end - power_start)*1000)
if not keep_sc:
spark_session.sparkContext.stop()
total_time_end = time.time()
-total_elapse = int((total_time_end - total_time_start)*1000)
+total_elapse = int((total_time_end - total_time_start) * 1000)
print("====== Power Test Time: {} milliseconds ======".format(power_elapse))
print("====== Total Time: {} milliseconds ======".format(total_elapse))
execution_time_list.append(
@@ -261,6 +291,9 @@ def run_query_stream(input_prefix,
if exit_code:
print("Above queries failed or completed with failed tasks. Please check the logs for the detailed reason.")

+sys.exit(exit_code)


def load_properties(filename):
myvars = {}
with open(filename) as myfile:
@@ -269,15 +302,16 @@ def load_properties(filename):
myvars[name.strip()] = var.strip()
return myvars


if __name__ == "__main__":
-parser = parser = argparse.ArgumentParser()
+parser = argparse.ArgumentParser()
parser.add_argument('input_prefix',
help='text to prepend to every input file path (e.g., "hdfs:///ds-generated-data"). ' +
-'If --hive or if input_format is "iceberg", this argument will be regarded as the value of property ' +
-'"spark.sql.catalog.spark_catalog.warehouse". Only default Spark catalog ' +
-'session name "spark_catalog" is supported now, customized catalog is not ' +
-'yet supported. Note if this points to a Delta Lake table, the path must be ' +
-'absolute. Issue: https://github.com/delta-io/delta/issues/555')
+'If --hive or if input_format is "iceberg", this argument will be regarded as the value of property ' +
+'"spark.sql.catalog.spark_catalog.warehouse". Only default Spark catalog ' +
+'session name "spark_catalog" is supported now, customized catalog is not ' +
+'yet supported. Note if this points to a Delta Lake table, the path must be ' +
+'absolute. Issue: https://github.com/delta-io/delta/issues/555')
parser.add_argument('query_stream_file',
help='query stream file that contains NDS queries in specific order')
parser.add_argument('--keep_sc',
@@ -290,13 +324,21 @@ def load_properties(filename):
default="")
parser.add_argument('--input_format',
help='type for input data source, e.g. parquet, orc, json, csv or iceberg, delta. ' +
-'Certain types are not fully supported by GPU reading, please refer to ' +
-'https://github.com/NVIDIA/spark-rapids/blob/branch-24.08/docs/compatibility.md ' +
-'for more details.',
+'Certain types are not fully supported by GPU reading, please refer to ' +
+'https://github.com/NVIDIA/spark-rapids/blob/branch-24.08/docs/compatibility.md ' +
+'for more details.',
choices=['parquet', 'orc', 'avro', 'csv', 'json', 'iceberg', 'delta'],
default='parquet')
parser.add_argument('--output_prefix',
help='text to prepend to every output file (e.g., "hdfs:///ds-parquet")')
+parser.add_argument('--json_summary_folder',
+help='Empty folder/path (will create if not exist) to save JSON summary file for each query.')
+parser.add_argument('--sub_queries',
+type=lambda s: [x.strip() for x in s.split(',')],
+help='comma separated list of queries to run. If not specified, all queries ' +
+'in the stream file will be run. e.g. "query1,query2,query3". Note, use ' +
+'"_part1" and "_part2" suffix for the following query names: ' +
+'query14, query23, query24, query39. e.g. query14_part1, query39_part2')
parser.add_argument('--output_format',
help='type of query output',
default='parquet')
@@ -308,7 +350,9 @@ def load_properties(filename):
args.property_file,
query_dict,
args.time_log,
+args.sub_queries,
args.input_format,
args.output_prefix,
args.keep_sc,
-args.output_format)
+args.output_format,
+args.json_summary_folder)
