From a7f191ba5947075066154a33da7908b24c412ccb Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 18 Sep 2024 08:44:22 +0800
Subject: [PATCH] [SPARK-49640][PS] Apply reservoir sampling in `SampledPlotBase`

### What changes were proposed in this pull request?
Apply reservoir sampling in `SampledPlotBase`.

### Why are the changes needed?
The existing sampling approach has two drawbacks:

1. it needs two jobs to sample `max_rows` rows:
   - `df.count()` to compute `fraction = max_rows / count`
   - `df.sample(fraction).to_pandas()` to do the sampling

2. `df.sample` is based on Bernoulli sampling, which **cannot** guarantee that the sampled size equals the expected `max_rows`, e.g.

```
In [1]: df = spark.range(10000)

In [2]: [df.sample(0.01).count() for i in range(0, 10)]
Out[2]: [96, 97, 95, 97, 105, 105, 105, 87, 95, 110]
```

The size of the sampled data fluctuates around the target size 10000 * 0.01 = 100. This relative deviation cannot be ignored when the input dataset is large and the sampling fraction is small.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual checks

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48105 from zhengruifeng/ps_sampling.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/pandas/plot/core.py | 51 ++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index 067c7db664dee..7630ecc398954 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -68,19 +68,52 @@ class SampledPlotBase:
     def get_sampled(self, data):
         from pyspark.pandas import DataFrame, Series
 
+        if not isinstance(data, (DataFrame, Series)):
+            raise TypeError("Only DataFrame and Series are supported for plotting.")
+        if isinstance(data, Series):
+            data = data.to_frame()
+
         fraction = get_option("plotting.sample_ratio")
-        if fraction is None:
-            fraction = 1 / (len(data) / get_option("plotting.max_rows"))
-            fraction = min(1.0, fraction)
-        self.fraction = fraction
-
-        if isinstance(data, (DataFrame, Series)):
-            if isinstance(data, Series):
-                data = data.to_frame()
+        if fraction is not None:
+            self.fraction = fraction
             sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction)
             return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
         else:
-            raise TypeError("Only DataFrame and Series are supported for plotting.")
+            from pyspark.sql import Observation
+
+            max_rows = get_option("plotting.max_rows")
+            observation = Observation("ps plotting")
+            sdf = data._internal.resolved_copy.spark_frame.observe(
+                observation, F.count(F.lit(1)).alias("count")
+            )
+
+            rand_col_name = "__ps_plotting_sampled_plot_base_rand__"
+            id_col_name = "__ps_plotting_sampled_plot_base_id__"
+
+            sampled = (
+                sdf.select(
+                    "*",
+                    F.rand().alias(rand_col_name),
+                    F.monotonically_increasing_id().alias(id_col_name),
+                )
+                .sort(rand_col_name)
+                .limit(max_rows + 1)
+                .coalesce(1)
+                .sortWithinPartitions(id_col_name)
+                .drop(rand_col_name, id_col_name)
+            )
+
+            pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
+
+            if len(pdf) > max_rows:
+                try:
+                    self.fraction = float(max_rows) / observation.get["count"]
+                except Exception:
+                    pass
+                return pdf[:max_rows]
+            else:
+                self.fraction = 1.0
+                return pdf
 
     def set_result_text(self, ax):
         assert hasattr(self, "fraction")
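
A minimal standalone sketch of the random-sort-and-limit idea applied in the hunk above (not taken from the patch; it assumes an active SparkSession bound to `spark` and uses a hypothetical `__rand__` column name). Unlike Bernoulli `df.sample`, it returns exactly `max_rows` rows in a single job whenever the input has at least that many:

```
from pyspark.sql import functions as F

max_rows = 100
df = spark.range(10000)

# Attach a random key, order by it, and keep the first `max_rows` rows.
# The output size is exact, unlike Bernoulli sampling whose size
# fluctuates around fraction * count.
sampled = (
    df.select("*", F.rand().alias("__rand__"))
    .sort("__rand__")
    .limit(max_rows)
    .drop("__rand__")
)

assert sampled.count() == max_rows
```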