Skip to content

Useful helper functions for PySpark dataframe operations

Notifications You must be signed in to change notification settings


Folders and files

Last commit message
Last commit date

Latest commit



46 Commits

Repository files navigation

Pyspark Helper Functions

[For less verbose and foolproof operations]

try:
    from pyspark import SparkConf
except ImportError:
    ! pip install pyspark==3.2.1

from pyspark import SparkConf
from pyspark.sql import SparkSession, types as st
from IPython.display import HTML

import spark.helpers as sh
# Setup Spark

conf = SparkConf().setMaster("local[1]").setAppName("examples")
spark = SparkSession.builder.config(conf=conf).getOrCreate()
# Load example datasets

dataframe_1 = spark.read.csv("./data/dataset_1.csv", header=True)
dataframe_2 = spark.read.csv("./data/dataset_2.csv", header=True)
html = (
    "<div style='float:left'><h4>Dataset 1:</h4>" +
    dataframe_1.toPandas().to_html() +
    "</div><div style='float:left; margin-left:50px;'><h4>Dataset 2:</h4>" +
    dataframe_2.toPandas().to_html() +
    "</div>"
)
HTML(html)

Dataset 1:

x1 x2 x3 x4 x5
0 A J 734 499 595.0
1 B J 357 202 525.0
2 C H 864 568 433.5
3 D J 530 703 112.3
4 E H 61 521 906.0
5 F H 482 496 13.0
6 G A 350 279 941.0
7 H C 171 267 423.0
8 I C 755 133 600.0
9 J A 228 765 7.0

Dataset 2:

x1 x3 x4 x6 x7
0 W K 391 140 872.0
1 X G 88 483 707.1
2 Y M 144 476 714.3
3 Z J 896 68 902.0
4 A O 946 187 431.0
5 B P 692 523 503.5
6 C Q 550 988 181.05
7 D R 50 419 42.0
8 E S 824 805 558.2
9 F T 69 722 721.0

1. Pandas-like group by

for group, data in sh.group_iterator(dataframe_1, "x2"):
    print(group, " => ", data.toPandas().shape[0])
A  =>  2
C  =>  2
H  =>  3
J  =>  3

[Multiple columns group by]

for group, data in sh.group_iterator(dataframe_1, ["x1", "x2"]):
    print(group, " => ", data.toPandas().shape[0])
('A', 'J')  =>  1
('B', 'J')  =>  1
('C', 'H')  =>  1
('D', 'J')  =>  1
('E', 'H')  =>  1
('F', 'H')  =>  1
('G', 'A')  =>  1
('H', 'C')  =>  1
('I', 'C')  =>  1
('J', 'A')  =>  1

2. Bulk-change schema

before = [(x["name"], x["type"]) for x in dataframe_1.schema.jsonValue()["fields"]]

schema = {
    "x2": st.IntegerType(),
    "x5": st.FloatType(),
}
new_dataframe = sh.change_schema(dataframe_1, schema)

after = [(x["name"], x["type"]) for x in new_dataframe.schema.jsonValue()["fields"]]
check = [
    ('x1', 'string'),
    ('x2', 'integer'),
    ('x3', 'string'),
    ('x4', 'string'),
    ('x5', 'float')
]

assert before != after
assert after == check

3. Improved joins

joined = sh.join(dataframe_1.select("x2", "x5"), dataframe_2, sh.JoinStatement("x2", "x1"))
x1 x2 x3 x4 x5 x6 x7
0 A A O 946 7.0 187 431.0
1 A A O 946 941.0 187 431.0
2 C C Q 550 600.0 988 181.05
3 C C Q 550 423.0 988 181.05

[When there are overlapping columns]

try:
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"))
except ValueError as error:
    print(f"Error raised as expected: {error}")
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="left")
Error raised as expected: 

Overlapping columns found in the dataframes: ['x1', 'x3', 'x4']
Please provide the `overwrite_strategy` argument therefore, to select a selection strategy:
	* "left": Use all the intersecting columns from the left dataframe
	* "right": Use all the intersecting columns from the right dataframe
	* [["x_in_left", "y_in_left"], ["z_in_right"]]: Provide column names for both
x1 x2 x3 x4 x5 x6 x7
0 A J 734 499 595.0 187 431.0
1 B J 357 202 525.0 523 503.5
2 C H 864 568 433.5 988 181.05
3 D J 530 703 112.3 419 42.0
4 E H 61 521 906.0 805 558.2
5 F H 482 496 13.0 722 721.0

[Keeping the duplicate columns from the right dataframe]

joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="right")
x1 x2 x3 x4 x5 x6 x7
0 A J O 946 595.0 187 431.0
1 B J P 692 525.0 523 503.5
2 C H Q 550 433.5 988 181.05
3 D J R 50 112.3 419 42.0
4 E H S 824 906.0 805 558.2
5 F H T 69 13.0 722 721.0

[Keeping the duplicate columns from both]

joined = sh.join(
    dataframe_1, dataframe_2, sh.JoinStatement("x1"),
    overwrite_strategy=[["x1", "x3"], ["x4"]]
)
x1 x2 x3 x4 x5 x6 x7
0 A J 734 946 595.0 187 431.0
1 B J 357 692 525.0 523 503.5
2 C H 864 550 433.5 988 181.05
3 D J 530 50 112.3 419 42.0
4 E H 61 824 906.0 805 558.2
5 F H 482 69 13.0 722 721.0

[Complex join]

x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
joined = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
x1 x2 x3 x4 x5 x6 x7
0 A J 734 499 595.0 187 431.0
1 B J 357 202 525.0 523 503.5
2 C H 864 568 433.5 988 181.05
3 D J 530 703 112.3 419 42.0
4 E H 61 521 906.0 805 558.2
5 F H 482 496 13.0 722 721.0
6 G A 350 279 941.0 483 707.1
7 J A 228 765 7.0 68 902.0

[Further nested joins are not supported]

(Perform sequential joins instead)

x1_x1 = sh.JoinStatement("x1")
x1_x2 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x2, "or")
statement_complex = sh.JoinStatement(statement, statement, "and")
try:
    joined = sh.join(dataframe_1, dataframe_2, statement_complex, overwrite_strategy="left")
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
Error raised as expected: [Recursive JoinStatement not implemented]