[SPARK-49640][PS] Apply reservoir sampling in SampledPlotBase
### What changes were proposed in this pull request?
Apply reservoir sampling in `SampledPlotBase`
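
As a rough illustration of the idea, the steps in the diff (attach a random key and a row id, keep the `max_rows + 1` rows with the smallest keys, restore the original order, then truncate) can be sketched in plain Python. This is only a minimal local sketch, not the Spark implementation: `sample_by_random_key` and its `seed` parameter are hypothetical names, and `len(rows)` stands in for the row count that the diff obtains through an `Observation`.

```python
import random


def sample_by_random_key(rows, max_rows, seed=None):
    """Return (sample, fraction) with at most max_rows rows in original order.

    Hypothetical helper: a local stand-in for the Spark pipeline in the diff.
    """
    rng = random.Random(seed)
    # Tag each row with a random key and its original position
    # (stand-ins for F.rand() and F.monotonically_increasing_id()).
    keyed = [(rng.random(), idx, row) for idx, row in enumerate(rows)]
    # Keep the max_rows + 1 rows with the smallest keys. The extra row
    # only signals that the input held more than max_rows rows.
    keyed.sort(key=lambda t: t[0])
    picked = keyed[: max_rows + 1]
    # Restore the original row order among the picked rows.
    picked.sort(key=lambda t: t[1])
    sample = [row for _, _, row in picked]
    if len(sample) > max_rows:
        # Like pdf[:max_rows] in the diff, this drops the extra row
        # that comes last in the original order.
        return sample[:max_rows], max_rows / len(rows)
    # The input fit within max_rows, so nothing was sampled away.
    return sample, 1.0


sample, fraction = sample_by_random_key(list(range(10_000)), max_rows=100, seed=0)
print(len(sample), fraction)
```

Fetching one row more than `max_rows` is what lets the caller tell "exactly `max_rows` rows" apart from "more than `max_rows` rows" without a separate counting job.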

### Why are the changes needed?
The existing sampling approach has two drawbacks:

1. It needs two jobs to sample `max_rows` rows:

- `df.count()` to compute `fraction = max_rows / count`
- `df.sample(fraction).to_pandas()` to do the sampling

2. `df.sample` is based on Bernoulli sampling, which **cannot** guarantee that the sampled size equals the expected `max_rows`, e.g.
```
In [1]: df = spark.range(10000)

In [2]: [df.sample(0.01).count() for i in range(0, 10)]
Out[2]: [96, 97, 95, 97, 105, 105, 105, 87, 95, 110]
```
The sampled size fluctuates around the target size of 10000 * 0.01 = 100.
This relative deviation cannot be ignored when the input dataset is large and the sampling fraction is small.
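
The spread can be reproduced without Spark. The sketch below assumes each row is kept independently with probability `fraction`, which is the Bernoulli model described above; the helper name `bernoulli_sample_size` and the seed are illustrative only. Under that model the sample size is binomial, so for 10000 rows and a fraction of 0.01 its standard deviation is sqrt(10000 * 0.01 * 0.99), about 9.95 rows, or roughly 10% of the target of 100.

```python
import math
import random


def bernoulli_sample_size(n, fraction, rng):
    """Count how many of n rows survive an independent coin flip per row."""
    return sum(1 for _ in range(n) if rng.random() < fraction)


rng = random.Random(42)  # arbitrary seed, used only for repeatability
n, fraction = 10_000, 0.01

# Ten independent sampling runs, mirroring the IPython session above.
sizes = [bernoulli_sample_size(n, fraction, rng) for _ in range(10)]

expected = n * fraction  # mean sample size
std = math.sqrt(n * fraction * (1 - fraction))  # binomial standard deviation

print(sizes)
print(f"expected={expected:.0f}, std={std:.2f}, relative std={std / expected:.1%}")
```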

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual checks

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48105 from zhengruifeng/ps_sampling.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Sep 18, 2024
1 parent 6fc176f commit a7f191b
Showing 1 changed file with 42 additions and 9 deletions.
51 changes: 42 additions & 9 deletions python/pyspark/pandas/plot/core.py
@@ -68,19 +68,52 @@ class SampledPlotBase:
     def get_sampled(self, data):
         from pyspark.pandas import DataFrame, Series
 
+        if not isinstance(data, (DataFrame, Series)):
+            raise TypeError("Only DataFrame and Series are supported for plotting.")
+        if isinstance(data, Series):
+            data = data.to_frame()
+
         fraction = get_option("plotting.sample_ratio")
-        if fraction is None:
-            fraction = 1 / (len(data) / get_option("plotting.max_rows"))
-            fraction = min(1.0, fraction)
-        self.fraction = fraction
-
-        if isinstance(data, (DataFrame, Series)):
-            if isinstance(data, Series):
-                data = data.to_frame()
+        if fraction is not None:
+            self.fraction = fraction
             sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction)
             return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
         else:
-            raise TypeError("Only DataFrame and Series are supported for plotting.")
+            from pyspark.sql import Observation
+
+            max_rows = get_option("plotting.max_rows")
+            observation = Observation("ps plotting")
+            sdf = data._internal.resolved_copy.spark_frame.observe(
+                observation, F.count(F.lit(1)).alias("count")
+            )
+
+            rand_col_name = "__ps_plotting_sampled_plot_base_rand__"
+            id_col_name = "__ps_plotting_sampled_plot_base_id__"
+
+            sampled = (
+                sdf.select(
+                    "*",
+                    F.rand().alias(rand_col_name),
+                    F.monotonically_increasing_id().alias(id_col_name),
+                )
+                .sort(rand_col_name)
+                .limit(max_rows + 1)
+                .coalesce(1)
+                .sortWithinPartitions(id_col_name)
+                .drop(rand_col_name, id_col_name)
+            )
+
+            pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
+
+            if len(pdf) > max_rows:
+                try:
+                    self.fraction = float(max_rows) / observation.get["count"]
+                except Exception:
+                    pass
+                return pdf[:max_rows]
+            else:
+                self.fraction = 1.0
+                return pdf
 
     def set_result_text(self, ax):
         assert hasattr(self, "fraction")
