Skip to content

Commit

Permalink
Allow JSON encoder to handle ndarray (#777)
Browse files Browse the repository at this point in the history
Co-authored-by: Saaketh Narayan <saaketh.narayan@databricks.com>
  • Loading branch information
srowen and snarayan21 authored Sep 9, 2024
1 parent 8273f11 commit 06fd29f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ class JSON(Encoding):
"""Store arbitrary data as JSON."""

def encode(self, obj: Any) -> bytes:
    """Serialize ``obj`` to UTF-8 encoded JSON bytes.

    NumPy values are converted to plain Python values before serialization:
    an ``np.ndarray`` becomes a (nested) list and a NumPy scalar becomes the
    corresponding Python scalar. The conversion applies at any depth, e.g.
    to an array stored as a dict value or a NumPy integer inside a list.

    Args:
        obj (Any): The object to serialize.

    Returns:
        bytes: The JSON text encoded as UTF-8.

    Raises:
        TypeError: If ``obj`` contains a value that JSON cannot represent.
    """

    def _numpy_to_builtin(value: Any) -> Any:
        # ``json.dumps`` invokes this hook only for values it cannot
        # serialize on its own, so natively supported types never reach it.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            builtin = value.item()
            # Some NumPy scalar types may have no Python equivalent, in which
            # case ``item()`` can hand back a NumPy scalar again; refuse those
            # instead of re-entering this hook.
            if not isinstance(builtin, np.generic):
                return builtin
        raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')

    # Convert a top-level array up front so that validation below receives
    # the list form rather than the ndarray.
    if isinstance(obj, np.ndarray):
        obj = obj.tolist()
    data = json.dumps(obj, default=_numpy_to_builtin)
    self._is_valid(obj, data)
    return data.encode('utf-8')
Expand Down
16 changes: 16 additions & 0 deletions tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,22 @@ def test_json_encode_decode(self, data: Any):
# Validate data content
assert dec_data == data

@pytest.mark.parametrize('data', [np.array([1]), np.array(['foo']), np.array([{'foo': 1}])])
def test_json_encode_decode_ndarray(self, data: Any):
    """Round-trip an ndarray through the JSON encoding.

    The encoded form must be ``bytes`` and decoding must yield the list
    equivalent of the input array.
    """
    codec = mdsEnc.JSON()
    # The JSON encoding reports no size.
    assert codec.size is None

    # Encoding produces raw bytes.
    payload = codec.encode(data)
    assert isinstance(payload, bytes)

    # Decoding yields a plain list, not an ndarray.
    restored = codec.decode(payload)
    assert isinstance(restored, list)

    # The content matches the array's list representation.
    expected = data.tolist()
    assert restored == expected


def test_json_invalid_data(self):
wrong_json_with_single_quotes = "{'name': 'streaming'}"
with pytest.raises(json.JSONDecodeError):
Expand Down

0 comments on commit 06fd29f

Please sign in to comment.