From 25a0cddd8413d27197c9b24211112e1ba95d79b3 Mon Sep 17 00:00:00 2001
From: Yinqing Hao
Date: Thu, 17 Oct 2024 13:43:37 +0800
Subject: [PATCH] Fix compare function

Signed-off-by: Yinqing Hao
---
 nds-h/nds_h_validate.py | 63 +++++++++++++++++------------------------
 1 file changed, 26 insertions(+), 37 deletions(-)

diff --git a/nds-h/nds_h_validate.py b/nds-h/nds_h_validate.py
index 487818b..823fc20 100644
--- a/nds-h/nds_h_validate.py
+++ b/nds-h/nds_h_validate.py
@@ -37,10 +37,10 @@ import os
 import re
 import time
 
-from decimal import *
+from decimal import Decimal
 from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql.types import *
+from pyspark.sql.types import DoubleType, FloatType
 from pyspark.sql.functions import col
 from nds_h_power import gen_sql_from_stream, get_query_subset
 
@@ -132,16 +132,10 @@ def collect_results(df: DataFrame,
         df = df.drop(*SKIP_COLUMNS[query_name])
 
     # apply sorting if specified
-    non_float_cols = [col(field.name) for \
-                      field in df.schema.fields \
-                      if (field.dataType.typeName() != FloatType.typeName()) \
-                      and \
-                      (field.dataType.typeName() != DoubleType.typeName())]
-    float_cols = [col(field.name) for \
-                  field in df.schema.fields \
-                  if (field.dataType.typeName() == FloatType.typeName()) \
-                  or \
-                  (field.dataType.typeName() == DoubleType.typeName())]
+    non_float_cols = [col(field.name) for field in df.schema.fields
+                      if field.dataType.typeName() not in (FloatType.typeName(), DoubleType.typeName())]
+    float_cols = [col(field.name) for field in df.schema.fields
+                  if field.dataType.typeName() in (FloatType.typeName(), DoubleType.typeName())]
     if ignore_ordering:
         df = df.sort(non_float_cols + float_cols)
 
@@ -172,20 +166,12 @@ def compare(expected, actual, epsilon=0.00001):
         # Double is converted to float in pyspark...
         if math.isnan(expected) and math.isnan(actual):
             return True
-        else:
-            return math.isclose(expected, actual, rel_tol=epsilon)
-    elif isinstance(expected, str) and isinstance(actual, str):
-        return expected == actual
-    elif expected == None and actual == None:
-        return True
-    elif expected != None and actual == None:
-        return False
-    elif expected == None and actual != None:
-        return False
-    elif isinstance(expected, Decimal) and isinstance(actual, Decimal):
         return math.isclose(expected, actual, rel_tol=epsilon)
-    else:
-        return expected == actual
+
+    if isinstance(expected, Decimal) and isinstance(actual, Decimal):
+        return math.isclose(expected, actual, rel_tol=epsilon)
+
+    return expected == actual
 
 def iterate_queries(spark_session: SparkSession,
                     input1: str,
@@ -239,22 +225,25 @@ def update_summary(prefix, unmatch_queries):
     for query_name in query_dict.keys():
         summary_wildcard = prefix + f'/*{query_name}-*.json'
         file_glob = glob.glob(summary_wildcard)
+
+        # Expect only one summary file for each query
         if len(file_glob) > 1:
             raise Exception(f"More than one summary file found for query {query_name} in folder {prefix}.")
         if len(file_glob) == 0:
             raise Exception(f"No summary file found for query {query_name} in folder {prefix}.")
-        for filename in file_glob:
-            with open(filename, 'r') as f:
-                summary = json.load(f)
-            if query_name in unmatch_queries:
-                if 'Completed' in summary['queryStatus'] or 'CompletedWithTaskFailures' in summary['queryStatus']:
-                    summary['queryValidationStatus'] = ['Fail']
-                else:
-                    summary['queryValidationStatus'] = ['NotAttempted']
+
+        filename = file_glob[0]
+        with open(filename, 'r') as f:
+            summary = json.load(f)
+        if query_name in unmatch_queries:
+            if 'Completed' in summary['queryStatus'] or 'CompletedWithTaskFailures' in summary['queryStatus']:
+                summary['queryValidationStatus'] = ['Fail']
             else:
-                summary['queryValidationStatus'] = ['Pass']
-            with open(filename, 'w') as f:
-                json.dump(summary, f, indent=2)
+                summary['queryValidationStatus'] = ['NotAttempted']
+        else:
+            summary['queryValidationStatus'] = ['Pass']
+        with open(filename, 'w') as f:
+            json.dump(summary, f, indent=2)
 
 if __name__ == "__main__":
     parser = parser = argparse.ArgumentParser()
@@ -313,4 +302,4 @@ def update_summary(prefix, unmatch_queries):
                            max_errors=args.max_errors,
                            epsilon=args.epsilon)
     if args.json_summary_folder:
-        update_summary(args.json_summary_folder, unmatch_queries)
\ No newline at end of file
+        update_summary(args.json_summary_folder, unmatch_queries)
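
Reviewer note (not part of the patch): the rewritten compare() collapses the
old if/elif chain into three early returns. A minimal standalone sketch of the
resulting behavior follows; the check values are illustrative assumptions, not
taken from the patch or its tests:

    import math
    from decimal import Decimal

    def compare(expected, actual, epsilon=0.00001):
        # Floats: NaN matches NaN; otherwise compare within a relative tolerance.
        if isinstance(expected, float) and isinstance(actual, float):
            if math.isnan(expected) and math.isnan(actual):
                return True
            return math.isclose(expected, actual, rel_tol=epsilon)
        # Decimals: same relative tolerance (math.isclose coerces them to float).
        if isinstance(expected, Decimal) and isinstance(actual, Decimal):
            return math.isclose(expected, actual, rel_tol=epsilon)
        # Everything else (str, int, None, ...) falls back to plain equality,
        # which subsumes the deleted str/None branches (None == None is True).
        return expected == actual

    assert compare(float('nan'), float('nan'))        # NaN treated as equal
    assert compare(1.0, 1.0000000001)                 # inside rel_tol
    assert not compare(1.0, 1.1)                      # outside rel_tol
    assert compare(Decimal('2.50'), Decimal('2.5'))   # tolerant Decimal compare
    assert not compare(None, 0)                       # plain equality fallback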
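
Reviewer note (not part of the patch): update_summary() now requires exactly
one JSON summary per query and rewrites it in place rather than looping over
the glob results. A sketch of the per-query flow; the helper name
mark_validation and its `matched` flag are hypothetical, while the
'queryStatus'/'queryValidationStatus' keys come from the patch itself:

    import glob
    import json

    def mark_validation(prefix, query_name, matched):
        # Hypothetical helper mirroring the patched loop body for one query.
        file_glob = glob.glob(prefix + f'/*{query_name}-*.json')
        if len(file_glob) != 1:
            raise Exception(f"Expected exactly one summary file for {query_name}.")
        filename = file_glob[0]
        with open(filename, 'r') as f:
            summary = json.load(f)
        if matched:
            summary['queryValidationStatus'] = ['Pass']
        elif 'Completed' in summary['queryStatus'] or \
                'CompletedWithTaskFailures' in summary['queryStatus']:
            # Query ran but its output did not match the baseline.
            summary['queryValidationStatus'] = ['Fail']
        else:
            # Query never completed, so validation was not attempted.
            summary['queryValidationStatus'] = ['NotAttempted']
        with open(filename, 'w') as f:
            json.dump(summary, f, indent=2)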