Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding new __repr__ for pyspark StructField such that the error logs explicitly show metadata differences #77

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions chispa/schema_comparer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from chispa.prettytable import PrettyTable
from chispa.bcolors import *
import chispa.six as six
from pyspark.sql.types import StructField


class SchemasNotEqualError(Exception):
"""The schemas are not equal"""
pass


class StructFieldPrettyPrint(StructField):
def __init__(self, structfield: StructField) -> None:
self.structfield = structfield

def __repr__(self):
return "StructField(%s, %s, %s, %s)" % (
self.structfield.name,
self.structfield.dataType,
str(self.structfield.nullable).lower(),
str(self.structfield.metadata)
)


def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False):
if not ignore_nullable and not ignore_metadata:
assert_basic_schema_equality(s1, s2)
Expand All @@ -30,9 +44,9 @@ def inner(s1, s2, ignore_nullable, ignore_metadata):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if are_structfields_equal(sf1, sf2, True):
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand All @@ -45,9 +59,9 @@ def assert_basic_schema_equality(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if sf1 == sf2:
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand All @@ -59,9 +73,9 @@ def assert_schema_equality_ignore_nullable(s1, s2):
zipped = list(six.moves.zip_longest(s1, s2))
for sf1, sf2 in zipped:
if are_structfields_equal(sf1, sf2, True):
t.add_row([blue(sf1), blue(sf2)])
t.add_row([blue(StructFieldPrettyPrint(sf1)), blue(StructFieldPrettyPrint(sf2))])
else:
t.add_row([sf1, sf2])
t.add_row([StructFieldPrettyPrint(sf1), StructFieldPrettyPrint(sf2)])
raise SchemasNotEqualError("\n" + t.get_string())


Expand Down