From 2cae2fe1af228395558e87233dc81f9ffbe609a6 Mon Sep 17 00:00:00 2001
From: Anjana S
Date: Fri, 26 Oct 2018 23:09:41 +0530
Subject: [PATCH] closes #23283

---
 doc/source/whatsnew/v0.24.0.txt |  1 +
 pandas/io/parquet.py            | 19 ++++++++++++-------
 pandas/tests/io/test_parquet.py | 24 +++++++++++++++++++++++-
 3 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt
index 768868d5857214..7afc134ddbf33d 100644
--- a/doc/source/whatsnew/v0.24.0.txt
+++ b/doc/source/whatsnew/v0.24.0.txt
@@ -213,6 +213,7 @@ Other Enhancements
 - New attribute :attr:`__git_version__` will return git commit sha of current build (:issue:`21295`).
 - Compatibility with Matplotlib 3.0 (:issue:`22790`).
 - Added :meth:`Interval.overlaps`, :meth:`IntervalArray.overlaps`, and :meth:`IntervalIndex.overlaps` for determining overlaps between interval-like objects (:issue:`21998`)
+- :func:`~DataFrame.to_parquet` now supports writing a DataFrame as a directory of parquet files partitioned by a subset of the columns (:issue:`23283`)
 - :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexsistent` (:issue:`8917`)

 .. _whatsnew_0240.api_breaking:
diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py
index aef1d84a19bc75..3661894b7df210 100644
--- a/pandas/io/parquet.py
+++ b/pandas/io/parquet.py
@@ -42,7 +42,6 @@ def get_engine(engine):


 class BaseImpl(object):
-
     api = None  # module

     @staticmethod
@@ -97,9 +96,9 @@ def __init__(self):
             )

         self._pyarrow_lt_060 = (
-            LooseVersion(pyarrow.__version__) < LooseVersion('0.6.0'))
+            LooseVersion(pyarrow.__version__) < LooseVersion('0.6.0'))
         self._pyarrow_lt_070 = (
-            LooseVersion(pyarrow.__version__) < LooseVersion('0.7.0'))
+            LooseVersion(pyarrow.__version__) < LooseVersion('0.7.0'))

         self.api = pyarrow

@@ -125,9 +124,14 @@ def write(self, df, path, compression='snappy',
         else:
             table = self.api.Table.from_pandas(df, **from_pandas_kwargs)

-        self.api.parquet.write_table(
-            table, path, compression=compression,
-            coerce_timestamps=coerce_timestamps, **kwargs)
+        if 'partition_cols' in kwargs:
+            self.api.parquet.write_to_dataset(
+                table, path, compression=compression,
+                coerce_timestamps=coerce_timestamps, **kwargs)
+        else:
+            self.api.parquet.write_table(
+                table, path, compression=compression,
+                coerce_timestamps=coerce_timestamps, **kwargs)

     def read(self, path, columns=None, **kwargs):
         path, _, _, should_close = get_filepath_or_buffer(path)
@@ -252,7 +256,8 @@ def to_parquet(df, path, engine='auto', compression='snappy', index=None,
     ----------
     df : DataFrame
     path : string
-        File path
+        File path (used as ``root_path`` if ``partition_cols`` is
+        provided with the 'pyarrow' engine).
     engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
         Parquet library to use. If 'auto', then the option
         ``io.parquet.engine`` is used. The default ``io.parquet.engine``
diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py
index 4c58d8ce29d8b6..82984326615df6 100644
--- a/pandas/tests/io/test_parquet.py
+++ b/pandas/tests/io/test_parquet.py
@@ -1,6 +1,8 @@
 """ test parquet compat """

 import pytest
+import tempfile
+import shutil
 import datetime
 from distutils.version import LooseVersion
 from warnings import catch_warnings
@@ -14,12 +16,14 @@

 try:
     import pyarrow  # noqa
+
     _HAVE_PYARROW = True
 except ImportError:
     _HAVE_PYARROW = False

 try:
     import fastparquet  # noqa
+
     _HAVE_FASTPARQUET = True
 except ImportError:
     _HAVE_FASTPARQUET = False
@@ -406,7 +410,6 @@ def test_write_ignoring_index(self, engine):
 class TestParquetPyArrow(Base):

     def test_basic(self, pa, df_full):
-
         df = df_full

         # additional supported types for pyarrow
@@ -478,6 +481,25 @@ def test_s3_roundtrip(self, df_compat, s3_resource, pa):
         check_round_trip(df_compat, pa,
                          path='s3://pandas-test/pyarrow.parquet')

+    def test_partition_cols_supported(self, pa_ge_070, df_full):
+        partition_cols = ['bool', 'int']
+        df = df_full
+        path = tempfile.mkdtemp()
+        df.to_parquet(path, partition_cols=partition_cols,
+                      compression=None)
+        import pyarrow.parquet as pq
+        dataset = pq.ParquetDataset(path, validate_schema=False)
+        assert len(dataset.pieces) == 2
+        assert len(dataset.partitions.partition_names) == 2
+        assert dataset.partitions.partition_names == set(partition_cols)
+        shutil.rmtree(path)
+
+    def test_ignore_partition_cols_lt_070(self, pa_lt_070, df_full):
+        partition_cols = ['bool', 'int']
+        pa = pa_lt_070
+        df = df_full
+        check_round_trip(df, pa, write_kwargs={'partition_cols': partition_cols})
+


 class TestParquetFastParquet(Base):
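
A minimal usage sketch of the behavior this patch adds (not part of the patch itself), assuming the 'pyarrow' engine with pyarrow >= 0.7.0 installed; the example frame, partition columns, and the 'out_dir' path are illustrative only:

    import pandas as pd
    import pyarrow.parquet as pq

    # Hypothetical frame; any columns with repeated values can serve as partition keys.
    df = pd.DataFrame({'country': ['IN', 'IN', 'US', 'US'],
                       'year': [2017, 2018, 2017, 2018],
                       'value': [1.0, 2.0, 3.0, 4.0]})

    # 'out_dir' is used as the dataset root_path; pyarrow lays the files out as
    # out_dir/country=IN/year=2017/<part>.parquet and so on, one directory level
    # per partition column.
    df.to_parquet('out_dir', engine='pyarrow',
                  partition_cols=['country', 'year'], compression=None)

    # Inspect the resulting partitioned dataset, as the new test does.
    dataset = pq.ParquetDataset('out_dir', validate_schema=False)
    print(dataset.partitions.partition_names)  # {'country', 'year'}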