diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py
index f0744be35..c28d0058d 100644
--- a/streaming/base/format/mds/encodings.py
+++ b/streaming/base/format/mds/encodings.py
@@ -510,6 +510,8 @@ class JSON(Encoding):
     """Store arbitrary data as JSON."""
 
     def encode(self, obj: Any) -> bytes:
+        if isinstance(obj, np.ndarray):
+            obj = obj.tolist()
         data = json.dumps(obj)
         self._is_valid(obj, data)
         return data.encode('utf-8')
diff --git a/tests/test_encodings.py b/tests/test_encodings.py
index f8ec67744..47fe2a6b2 100644
--- a/tests/test_encodings.py
+++ b/tests/test_encodings.py
@@ -286,6 +286,22 @@ def test_json_encode_decode(self, data: Any):
         # Validate data content
         assert dec_data == data
 
+    @pytest.mark.parametrize('data', [np.array([1]), np.array(['foo']), np.array([{'foo': 1}])])
+    def test_json_encode_decode_ndarray(self, data: Any):
+        json_enc = mdsEnc.JSON()
+        assert json_enc.size is None
+
+        # Test encode
+        enc_data = json_enc.encode(data)
+        assert isinstance(enc_data, bytes)
+
+        # Test decode
+        dec_data = json_enc.decode(enc_data)
+        assert isinstance(dec_data, list)
+
+        # Validate data content
+        assert dec_data == data.tolist()
+
     def test_json_invalid_data(self):
         wrong_json_with_single_quotes = "{'name': 'streaming'}"
         with pytest.raises(json.JSONDecodeError):