diff --git a/setup.py b/setup.py
index 1e3cd36980c..8dfcd21c5d8 100644
--- a/setup.py
+++ b/setup.py
@@ -115,7 +115,7 @@
     # For smart caching dataset processing
     "dill>=0.3.0,<0.3.8",  # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
     # For performance gains with apache arrow
-    "pandas<2.1.0",  # temporary pin
+    "pandas",
     # for downloading datasets over HTTPS
     "requests>=2.19.0",
     # progress bars in download and scripts
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index fb158986c55..9fd1caedb86 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4844,7 +4844,7 @@ def to_json(
-                Now, `index` defaults to `False` if `orint` is `"split"` or `"table"` is specified.
+                Now, `index` defaults to `False` if `orient` is `"split"` or `"table"`.

                 If you would like to write the index, pass `index=True`.
diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py
index 68728f52875..5f43efec542 100644
--- a/src/datasets/io/json.py
+++ b/src/datasets/io/json.py
@@ -93,7 +93,8 @@ def write(self) -> int:
         _ = self.to_json_kwargs.pop("path_or_buf", None)
         orient = self.to_json_kwargs.pop("orient", "records")
         lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
-        index = self.to_json_kwargs.pop("index", False if orient in ["split", "table"] else True)
+        if "index" not in self.to_json_kwargs and orient in ["split", "table"]:
+            self.to_json_kwargs["index"] = False
         compression = self.to_json_kwargs.pop("compression", None)

         if compression not in [None, "infer", "gzip", "bz2", "xz"]:
@@ -101,29 +102,25 @@ def write(self) -> int:
         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with fsspec.open(self.path_or_buf, "wb", compression=compression) as buffer:
-                written = self._write(file_obj=buffer, orient=orient, lines=lines, index=index, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
         else:
             if compression:
                 raise NotImplementedError(
                     f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                     " was passed. Please provide a local path instead."
                 )
-            written = self._write(
-                file_obj=self.path_or_buf, orient=orient, lines=lines, index=index, **self.to_json_kwargs
-            )
+            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
         return written

     def _batch_json(self, args):
-        offset, orient, lines, index, to_json_kwargs = args
+        offset, orient, lines, to_json_kwargs = args
         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
-        json_str = batch.to_pandas().to_json(
-            path_or_buf=None, orient=orient, lines=lines, index=index, **to_json_kwargs
-        )
+        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
         if not json_str.endswith("\n"):
             json_str += "\n"
         return json_str.encode(self.encoding)
@@ -133,7 +130,6 @@ def _write(
         file_obj: BinaryIO,
         orient,
         lines,
-        index,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -149,7 +145,7 @@ def _write(
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating json from Arrow format",
             ):
-                json_str = self._batch_json((offset, orient, lines, index, to_json_kwargs))
+                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
                 written += file_obj.write(json_str)
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
@@ -157,7 +153,7 @@ def _write(
                 for json_str in logging.tqdm(
                     pool.imap(
                         self._batch_json,
-                        [(offset, orient, lines, index, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
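
Usage sketch (illustrative, not part of the patch): assuming a `datasets` build that includes this diff, the new `index` handling behaves as below. The file names are made up for the example; the point is that `index` now flows through `to_json_kwargs` only when it is meaningful for the chosen `orient`, which is what pandas 2.x requires.

    from datasets import Dataset

    ds = Dataset.from_dict({"a": [1, 2], "b": ["x", "y"]})

    # orient="records" (the default): no index is forwarded to pandas at all,
    # since pandas 2.x raises a ValueError when an explicit index is combined
    # with orient="records". The old code passed index=True here unconditionally.
    ds.to_json("data.jsonl")

    # orient="split" / "table": write() injects index=False into to_json_kwargs
    # unless the caller set it, so the pandas RangeIndex is not serialized.
    ds.to_json("data_split.json", orient="split")

    # Opting back in still works, because "index" now reaches
    # pandas.DataFrame.to_json via **to_json_kwargs instead of being popped.
    ds.to_json("data_split_indexed.json", orient="split", index=True)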