Skip to content

Commit

Permalink
Add MapType as JSON-compatible (#776)
Browse files (browse the repository at this point in the history)
Co-authored-by: Saaketh Narayan <saaketh.narayan@databricks.com>
  • Loading branch information
srowen and snarayan21 authored Sep 10, 2024
1 parent 06fd29f commit 989e354
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
4 changes: 3 additions & 1 deletion streaming/base/converters/dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import (ArrayType, BinaryType, BooleanType, ByteType, DateType,
DayTimeIntervalType, DecimalType, DoubleType, FloatType,
IntegerType, LongType, NullType, ShortType, StringType,
IntegerType, LongType, MapType, NullType, ShortType, StringType,
StructField, StructType, TimestampNTZType, TimestampType)
except ImportError as e:
e.msg = get_import_exception_message(e.name, extra_deps='spark') # pyright: ignore
Expand Down Expand Up @@ -70,6 +70,8 @@ def is_json_compatible(data_type: Any):
return all(is_json_compatible(field.dataType) for field in data_type.fields)
elif isinstance(data_type, ArrayType):
return is_json_compatible(data_type.elementType)
elif isinstance(data_type, MapType):
return is_json_compatible(data_type.keyType) and is_json_compatible(data_type.valueType)
elif isinstance(data_type, (StringType, IntegerType, FloatType, BooleanType, NullType)):
return True
else:
Expand Down
17 changes: 11 additions & 6 deletions tests/base/converters/test_dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,18 @@ def test_is_json_compatible(self):
]), True), True)
])

valid_schemas = [message_schema, prompt_response_schema, combined_schema]
string_map_keys_schema = StructType(
[StructField('map_field', MapType(StringType(), StringType()), nullable=True)])

valid_schemas = [
message_schema, prompt_response_schema, combined_schema, string_map_keys_schema
]

schema_with_binary = StructType([StructField('data', BinaryType(), nullable=True)])

# Schema with MapType having non-string keys
schema_with_non_string_map_keys = StructType(
[StructField('map_field', MapType(IntegerType(), StringType()), nullable=True)])
non_string_map_keys_schema = StructType(
[StructField('map_field', MapType(BinaryType(), StringType()), nullable=True)])

# Schema with DateType and TimestampType
schema_with_date_and_timestamp = StructType([
Expand All @@ -433,14 +438,14 @@ def test_is_json_compatible(self):
])

invalid_schemas = [
schema_with_binary, schema_with_non_string_map_keys, schema_with_date_and_timestamp
schema_with_binary, non_string_map_keys_schema, schema_with_date_and_timestamp
]

for s in valid_schemas:
assert is_json_compatible(s)
assert is_json_compatible(s), str(s)

for s in invalid_schemas:
assert not is_json_compatible(s)
assert not is_json_compatible(s), str(s)

def test_complex_schema(self,
complex_dataframe: Any,
Expand Down

0 comments on commit 989e354

Please sign in to comment.