From 373f4acac26ba18b53b1eedd9145f4a716f974bb Mon Sep 17 00:00:00 2001
From: IMC07 <135529687+IMC07@users.noreply.github.com>
Date: Mon, 22 Jan 2024 16:44:06 +0100
Subject: [PATCH] rewrote generate_summarised_row_dq_res (#66)

* rewrote generate_summarised_row_dq_res

* make fmt

* edit test
---
 spark_expectations/sinks/utils/writer.py | 46 ++++++++++++------------
 tests/sinks/utils/test_writer.py         |  6 ++--
 2 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/spark_expectations/sinks/utils/writer.py b/spark_expectations/sinks/utils/writer.py
index 3cc2ad3c..f208e40f 100644
--- a/spark_expectations/sinks/utils/writer.py
+++ b/spark_expectations/sinks/utils/writer.py
@@ -12,6 +12,7 @@
     create_map,
     explode,
     to_json,
+    col,
 )
 from spark_expectations import _log
 from spark_expectations.core.exceptions import (
@@ -440,32 +441,33 @@ def generate_summarised_row_dq_res(self, df: DataFrame, rule_type: str) -> None:
 
         """
         try:
+            df_exploded = df.select(
+                explode(f"meta_{rule_type}_results").alias("row_dq_res")
+            )
 
-            def update_dict(accumulator: dict) -> dict:  # pragma: no cover
-                if accumulator.get("failed_row_count") is None:  # pragma: no cover
-                    accumulator["failed_row_count"] = str(2)  # pragma: no cover
-                else:  # pragma: no cover
-                    accumulator["failed_row_count"] = str(  # pragma: no cover
-                        int(accumulator["failed_row_count"]) + 1  # pragma: no cover
-                    )  # pragma: no cover
-
-                return accumulator  # pragma: no cover
-
-            summarised_row_dq_dict: Dict[str, Dict[str, str]] = (
-                df.select(explode(f"meta_{rule_type}_results").alias("row_dq_res"))
-                .rdd.map(
-                    lambda rule_meta_dict: (
-                        rule_meta_dict[0]["rule"],
-                        {**rule_meta_dict[0], "failed_row_count": 1},
-                    )
-                )
-                .reduceByKey(lambda acc, itr: update_dict(acc))
-            ).collectAsMap()
+            keys = (
+                df_exploded.select(explode("row_dq_res"))
+                .select("key")
+                .distinct()
+                .rdd.flatMap(lambda x: x)
+                .collect()
+            )
+            nested_keys = [col("row_dq_res").getItem(k).alias(k) for k in keys]
 
-            self._context.set_summarised_row_dq_res(
-                list(summarised_row_dq_dict.values())
+            df_select = df_exploded.select(*nested_keys)
+            df_pivot = (
+                df_select.groupBy(df_select.columns)
+                .count()
+                .withColumnRenamed("count", "failed_row_count")
             )
+            keys += ["failed_row_count"]
 
+            summarised_row_dq_list = df_pivot.rdd.map(
+                lambda x: {i: x[i] for i in keys}
+            ).collect()
+
+            self._context.set_summarised_row_dq_res(summarised_row_dq_list)
+
         except Exception as e:
             raise SparkExpectationsMiscException(
                 f"error occurred created summarised row dq statistics {e}"
diff --git a/tests/sinks/utils/test_writer.py b/tests/sinks/utils/test_writer.py
index 74298040..d33c295c 100644
--- a/tests/sinks/utils/test_writer.py
+++ b/tests/sinks/utils/test_writer.py
@@ -635,8 +635,8 @@ def test_write_error_records_final_dependent(save_df_as_table,
                 {"meta_row_dq_results": [{"rule": "rule2"}]},
             ],
             [
-                {"rule": "rule1", "failed_row_count": "2"},
-                {"rule": "rule2", "failed_row_count": "2"},
+                {"rule": "rule1", "failed_row_count": 2},
+                {"rule": "rule2", "failed_row_count": 2},
             ]
         ),
         (
@@ -645,7 +645,7 @@ def test_write_error_records_final_dependent(save_df_as_table,
                 {"meta_row_dq_results": [{"rule": "rule1"}]},
            ],
             [
-                {"rule": "rule1", "failed_row_count": "2"},
+                {"rule": "rule1", "failed_row_count": 2},
             ]
         )
     ])
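
Reviewer note: the rewrite replaces the old `rdd.map(...).reduceByKey(...)` accumulator (untestable, hence the `# pragma: no cover` markers) with plain DataFrame operations: explode the array of per-rule result maps, turn the distinct map keys into columns, then groupBy + count to get one summary row per distinct result map. A side effect is that `failed_row_count` is now an integer rather than a string, which is why the test expectations change from `"2"` to `2`.

Below is a minimal, self-contained sketch of that pattern against a toy DataFrame. Everything outside the diff (the local `spark` session, sample schema and data) is illustrative, not part of spark_expectations:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, explode

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Each row carries an array of per-rule result maps, mirroring the
    # shape of the `meta_row_dq_results` column.
    df = spark.createDataFrame(
        [
            ([{"rule": "rule1"}, {"rule": "rule2"}],),
            ([{"rule": "rule1"}],),
        ],
        "meta_row_dq_results: array<map<string,string>>",
    )

    # 1. One output row per rule-result map.
    df_exploded = df.select(explode("meta_row_dq_results").alias("row_dq_res"))

    # 2. Exploding a map column yields `key`/`value` columns; collect the
    #    distinct keys and project each one to its own column.
    keys = (
        df_exploded.select(explode("row_dq_res"))
        .select("key")
        .distinct()
        .rdd.flatMap(lambda x: x)
        .collect()
    )
    df_select = df_exploded.select(
        *[col("row_dq_res").getItem(k).alias(k) for k in keys]
    )

    # 3. Identical result maps collapse into one group; the group size is
    #    the failed-row count for that rule.
    df_pivot = (
        df_select.groupBy(df_select.columns)
        .count()
        .withColumnRenamed("count", "failed_row_count")
    )
    df_pivot.show()
    # Expected: rule1 -> failed_row_count 2, rule2 -> failed_row_count 1
    # (row order may vary)

Keeping the aggregation in DataFrame operations lets Catalyst plan the shuffle instead of falling back to an opaque Python reduceByKey, and the single remaining `rdd.map` only converts the already-aggregated rows to dicts for `set_summarised_row_dq_res`.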