diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 067c7db664dee..7630ecc398954 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -68,19 +68,52 @@ class SampledPlotBase: def get_sampled(self, data): from pyspark.pandas import DataFrame, Series + if not isinstance(data, (DataFrame, Series)): + raise TypeError("Only DataFrame and Series are supported for plotting.") + if isinstance(data, Series): + data = data.to_frame() + fraction = get_option("plotting.sample_ratio") - if fraction is None: - fraction = 1 / (len(data) / get_option("plotting.max_rows")) - fraction = min(1.0, fraction) - self.fraction = fraction - - if isinstance(data, (DataFrame, Series)): - if isinstance(data, Series): - data = data.to_frame() + if fraction is not None: + self.fraction = fraction sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction) return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas() else: - raise TypeError("Only DataFrame and Series are supported for plotting.") + from pyspark.sql import Observation + + max_rows = get_option("plotting.max_rows") + observation = Observation("ps plotting") + sdf = data._internal.resolved_copy.spark_frame.observe( + observation, F.count(F.lit(1)).alias("count") + ) + + rand_col_name = "__ps_plotting_sampled_plot_base_rand__" + id_col_name = "__ps_plotting_sampled_plot_base_id__" + + sampled = ( + sdf.select( + "*", + F.rand().alias(rand_col_name), + F.monotonically_increasing_id().alias(id_col_name), + ) + .sort(rand_col_name) + .limit(max_rows + 1) + .coalesce(1) + .sortWithinPartitions(id_col_name) + .drop(rand_col_name, id_col_name) + ) + + pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas() + + if len(pdf) > max_rows: + try: + self.fraction = float(max_rows) / observation.get["count"] + except Exception: + pass + return pdf[:max_rows] + else: + self.fraction = 1.0 + return pdf def set_result_text(self, ax): assert hasattr(self, "fraction")