From 989e35483ebbd8ef94446a015e643e8b099887b8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 10 Sep 2024 10:52:10 -0500 Subject: [PATCH] Add MapType as JSON-compatible (#776) Co-authored-by: Saaketh Narayan --- streaming/base/converters/dataframe_to_mds.py | 4 +++- tests/base/converters/test_dataframe_to_mds.py | 17 +++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 47fdec772..2b8891e79 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -19,7 +19,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import (ArrayType, BinaryType, BooleanType, ByteType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, - IntegerType, LongType, NullType, ShortType, StringType, + IntegerType, LongType, MapType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType) except ImportError as e: e.msg = get_import_exception_message(e.name, extra_deps='spark') # pyright: ignore @@ -70,6 +70,8 @@ def is_json_compatible(data_type: Any): return all(is_json_compatible(field.dataType) for field in data_type.fields) elif isinstance(data_type, ArrayType): return is_json_compatible(data_type.elementType) + elif isinstance(data_type, MapType): + return is_json_compatible(data_type.keyType) and is_json_compatible(data_type.valueType) elif isinstance(data_type, (StringType, IntegerType, FloatType, BooleanType, NullType)): return True else: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index c8889eb1a..a79e98f2d 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -418,13 +418,18 @@ def test_is_json_compatible(self): ]), True), True) ]) - valid_schemas = [message_schema, prompt_response_schema, combined_schema] + string_map_keys_schema = StructType( + [StructField('map_field', MapType(StringType(), StringType()), nullable=True)]) + + valid_schemas = [ + message_schema, prompt_response_schema, combined_schema, string_map_keys_schema + ] schema_with_binary = StructType([StructField('data', BinaryType(), nullable=True)]) # Schema with MapType having non-string keys - schema_with_non_string_map_keys = StructType( - [StructField('map_field', MapType(IntegerType(), StringType()), nullable=True)]) + non_string_map_keys_schema = StructType( + [StructField('map_field', MapType(BinaryType(), StringType()), nullable=True)]) # Schema with DateType and TimestampType schema_with_date_and_timestamp = StructType([ @@ -433,14 +438,14 @@ def test_is_json_compatible(self): ]) invalid_schemas = [ - schema_with_binary, schema_with_non_string_map_keys, schema_with_date_and_timestamp + schema_with_binary, non_string_map_keys_schema, schema_with_date_and_timestamp ] for s in valid_schemas: - assert is_json_compatible(s) + assert is_json_compatible(s), str(s) for s in invalid_schemas: - assert not is_json_compatible(s) + assert not is_json_compatible(s), str(s) def test_complex_schema(self, complex_dataframe: Any,