From 13fb2eb607d3bc1f6ac62398c4564d5ff2c3afad Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 09:22:06 -0800 Subject: [PATCH] Offload exception to mds_write. (#528) * remove exception throwing and let mds_write handle * fix tests * update * update * fix lints * fix lints --------- Co-authored-by: Karan Jariwala --- streaming/base/converters/dataframe_to_mds.py | 27 ++++--------------- .../base/converters/test_dataframe_to_mds.py | 20 +++++++------- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index c74460b3f..46370c004 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -123,7 +123,7 @@ def dataframeToMDS(dataframe: DataFrame, merge_index: bool = True, mds_kwargs: Optional[Dict[str, Any]] = None, udf_iterable: Optional[Callable] = None, - udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]: + udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[str, str]: """Deprecated API Signature. To be replaced by dataframe_to_mds @@ -138,7 +138,7 @@ def dataframe_to_mds(dataframe: DataFrame, merge_index: bool = True, mds_kwargs: Optional[Dict[str, Any]] = None, udf_iterable: Optional[Callable] = None, - udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]: + udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[str, str]: """Execute a spark dataframe to MDS conversion process. This method orchestrates the conversion of a spark dataframe into MDS format by processing the @@ -157,8 +157,6 @@ def dataframe_to_mds(dataframe: DataFrame, Returns: mds_path (str or (str,str)): actual local and remote path were used - fail_count (int): number of records failed to be converted - Notes: - The method creates a SparkSession if not already available. - The 'udf_kwargs' dictionaries can be used to pass additional @@ -192,8 +190,6 @@ def write_mds(iterator: Iterable): if merge_index: kwargs['keep_local'] = True # need to keep workers' locals to do merge - count = 0 - with MDSWriter(**kwargs) as mds_writer: for pdf in iterator: if udf_iterable is not None: @@ -206,11 +202,7 @@ def write_mds(iterator: Iterable): f'{type(records)}') for sample in records: - try: - mds_writer.write(sample) - except Exception as ex: - raise RuntimeError(f'failed to write sample: {sample}') from ex - count += 1 + mds_writer.write(sample) yield pd.concat([ pd.Series([os.path.join(partition_path[0], get_index_basename())], @@ -219,8 +211,7 @@ def write_mds(iterator: Iterable): os.path.join(partition_path[1], get_index_basename()) if partition_path[1] != '' else '' ], - name='mds_path_remote'), - pd.Series([count], name='fail_count') + name='mds_path_remote') ], axis=1) @@ -267,7 +258,6 @@ def write_mds(iterator: Iterable): result_schema = StructType([ StructField('mds_path_local', StringType(), False), StructField('mds_path_remote', StringType(), False), - StructField('fail_count', IntegerType(), False) ]) partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() @@ -285,11 +275,4 @@ def write_mds(iterator: Iterable): if not keep_local_files: shutil.rmtree(cu.local, ignore_errors=True) - sum_fail_count = 0 - for row in partitions: - sum_fail_count += row['fail_count'] - - if sum_fail_count > 0: - logger.warning( - f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}') - return mds_path, sum_fail_count + return mds_path diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a99ea973a..23c4a3e92 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -68,13 +68,13 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: } with pytest.raises(ValueError, match=f'.*is not supported by MDSWriter.*'): - _, _ = dataframe_to_mds(dataframe.select(col('id'), col('dept'), col('properties')), - merge_index=merge_index, - mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(dataframe.select(col('id'), col('dept'), col('properties')), + merge_index=merge_index, + mds_kwargs=mds_kwargs) - _, _ = dataframe_to_mds(dataframe.select(col('id'), col('dept')), - merge_index=merge_index, - mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(dataframe.select(col('id'), col('dept')), + merge_index=merge_index, + mds_kwargs=mds_kwargs) if keep_local: assert len(os.listdir(out)) > 0, f'{out} is empty' @@ -115,7 +115,7 @@ def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_c if use_columns: mds_kwargs['columns'] = user_defined_columns - _, _ = dataframe_to_mds(decimal_dataframe, merge_index=True, mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(decimal_dataframe, merge_index=True, mds_kwargs=mds_kwargs) assert len(os.listdir(out)) > 0, f'{out} is empty' def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str, str]): @@ -126,7 +126,7 @@ def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str, 'columns': user_defined_columns, } with pytest.raises(ValueError, match=f'.*is not a column of input dataframe.*'): - _, _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs) user_defined_columns = {'id': 'strr', 'dept': 'str'} @@ -135,7 +135,7 @@ def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str, 'columns': user_defined_columns, } with pytest.raises(ValueError, match=f'.* is not supported by MDSWriter.*'): - _, _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs) @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('merge_index', [True, False]) @@ -154,7 +154,7 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer 'size_limit': 1 << 26 } - _, _ = dataframe_to_mds(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) + _ = dataframe_to_mds(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) if keep_local: assert len(os.listdir(out)) > 0, f'{out} is empty'