Fix to_json ValueError and remove pandas pin (#6201)
* Unpin pandas

* Fix JsonDatasetWriter

* Fix typo in docstring

* Leave the default index for orients other than "split" or "table"

* Pass index within to_json_kwargs when relevant
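
For context, a minimal sketch (not part of the commit) of the pandas behavior being worked around: starting with pandas 2.1, `DataFrame.to_json` rejects an explicitly passed `index` argument for orients other than `"split"` and `"table"`, so the writer should only forward `index` for those orients.

import pandas as pd

df = pd.DataFrame({"a": [1, 2]})

# Leaving `index` at its default works for any orient:
df.to_json(orient="records", lines=True)

# On pandas >= 2.1 an explicit `index` with orient="records" raises
# ValueError; the old writer always forwarded index=True/False, which
# is why the temporary pandas<2.1.0 pin was needed:
try:
    df.to_json(orient="records", lines=True, index=True)
except ValueError as err:
    print(err)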
albertvillanova committed Oct 24, 2023
1 parent 5cf9bbe commit e8e31dd
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -115,7 +115,7 @@
     # For smart caching dataset processing
     "dill>=0.3.0,<0.3.8", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
     # For performance gains with apache arrow
-    "pandas<2.1.0", # temporary pin
+    "pandas",
     # for downloading datasets over HTTPS
     "requests>=2.19.0",
     # progress bars in download and scripts
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
@@ -4844,7 +4844,7 @@ def to_json(
         <Changed version="2.11.0">
-        Now, `index` defaults to `False` if `orint` is `"split"` or `"table"` is specified.
+        Now, `index` defaults to `False` if `orient` is `"split"` or `"table"`.
         If you would like to write the index, pass `index=True`.
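
As a usage illustration of the documented default (a sketch only; the in-memory dataset and file names are invented):

from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2]})
ds.to_json("data.jsonl")                             # orient="records", JSON Lines
ds.to_json("data.json", orient="split")              # index defaults to False
ds.to_json("data.json", orient="split", index=True)  # opt back in to writing the index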
20 changes: 8 additions & 12 deletions src/datasets/io/json.py
@@ -93,37 +93,34 @@ def write(self) -> int:
         _ = self.to_json_kwargs.pop("path_or_buf", None)
         orient = self.to_json_kwargs.pop("orient", "records")
         lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
-        index = self.to_json_kwargs.pop("index", False if orient in ["split", "table"] else True)
+        if "index" not in self.to_json_kwargs and orient in ["split", "table"]:
+            self.to_json_kwargs["index"] = False
         compression = self.to_json_kwargs.pop("compression", None)

         if compression not in [None, "infer", "gzip", "bz2", "xz"]:
             raise NotImplementedError(f"`datasets` currently does not support {compression} compression")

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with fsspec.open(self.path_or_buf, "wb", compression=compression) as buffer:
-                written = self._write(file_obj=buffer, orient=orient, lines=lines, index=index, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
         else:
             if compression:
                 raise NotImplementedError(
                     f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                     " was passed. Please provide a local path instead."
                 )
-            written = self._write(
-                file_obj=self.path_or_buf, orient=orient, lines=lines, index=index, **self.to_json_kwargs
-            )
+            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
         return written

     def _batch_json(self, args):
-        offset, orient, lines, index, to_json_kwargs = args
+        offset, orient, lines, to_json_kwargs = args

         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
-        json_str = batch.to_pandas().to_json(
-            path_or_buf=None, orient=orient, lines=lines, index=index, **to_json_kwargs
-        )
+        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
         if not json_str.endswith("\n"):
             json_str += "\n"
         return json_str.encode(self.encoding)
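
After this change, each batch boils down to Arrow table -> pandas -> `to_json`, with `index` injected only when pandas accepts it. A rough standalone sketch (invented values; the real batch comes from `query_table`):

import pyarrow as pa

batch = pa.table({"a": [1, 2]})

orient, lines = "records", True
to_json_kwargs = {}
if "index" not in to_json_kwargs and orient in ["split", "table"]:
    to_json_kwargs["index"] = False  # only injected for orients that accept it

json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
# json_str == '{"a":1}\n{"a":2}'; the writer appends "\n" if missing, then encodes it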
@@ -133,7 +130,6 @@ def _write(
         file_obj: BinaryIO,
         orient,
         lines,
-        index,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -149,15 +145,15 @@ def _write(
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating json from Arrow format",
             ):
-                json_str = self._batch_json((offset, orient, lines, index, to_json_kwargs))
+                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
                 written += file_obj.write(json_str)
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
                 for json_str in logging.tqdm(
                     pool.imap(
                         self._batch_json,
-                        [(offset, orient, lines, index, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
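
A side note on the unchanged `total=` expression above: it is ceiling division of `num_rows` by `batch_size`, for example:

import math

num_rows, batch_size = 10, 3
total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
assert total == math.ceil(num_rows / batch_size) == 4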
