From 7b05cd4c6b40eac79b9e26148db890b0f3594439 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 1 Aug 2019 15:43:45 -0700
Subject: [PATCH] Fixes to Arrow pydoc and README (#405)

---
 tensorflow_io/arrow/README.md                 | 24 +++++++++----------
 .../arrow/python/ops/arrow_dataset_ops.py     |  8 +++----
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/tensorflow_io/arrow/README.md b/tensorflow_io/arrow/README.md
index 477638e13..e0454cbc9 100644
--- a/tensorflow_io/arrow/README.md
+++ b/tensorflow_io/arrow/README.md
@@ -15,10 +15,10 @@ Dataset. Example usage:
 
 ```python
 import tensorflow as tf
-from tensorflow_io.arrow import ArrowDataset
+import tensorflow_io.arrow as arrow_io
 
 # Assume `df` is an existing Pandas DataFrame
-dataset = ArrowDataset.from_pandas(df)
+dataset = arrow_io.ArrowDataset.from_pandas(df)
 
 iterator = dataset.make_one_shot_iterator()
 next_element = iterator.get_next()
@@ -49,11 +49,11 @@ write_feather(df, '/path/to/a.feather')
 
 ```python
 import tensorflow as tf
-from tensorflow_io.arrow import ArrowFeatherDataset
+import tensorflow_io.arrow as arrow_io
 
 # Each Feather file must have the same column types, here we use the above
 # DataFrame which has 2 columns with dtypes=(int32, float32)
-dataset = ArrowFeatherDataset(
+dataset = arrow_io.ArrowFeatherDataset(
     ['/path/to/a.feather', '/path/to/b.feather'],
     columns=(0, 1),
     output_types=(tf.int32, tf.float32),
@@ -72,7 +72,7 @@ with tf.Session() as sess:
 ```
 
 An alternate constructor can also be used to infer output types and shapes from
-a given `pyarrow.Schema`, e.g. `dataset = ArrowFeatherDataset.from_schema(filenames, schema)`
+a given `pyarrow.Schema`, e.g. `dataset = arrow_io.ArrowFeatherDataset.from_schema(filenames, schema)`
 
 ## From a Stream of Arrow Record Batches
 
@@ -80,8 +80,8 @@ The `ArrowStreamDataset` provides a Dataset that will connect to one or more
 endpoints that are serving Arrow record batches in the Arrow stream format.
 See [here](https://arrow.apache.org/docs/python/ipc.html#writing-and-reading-streams)
 for more on the stream format. Currently supported endpoints are a POSIX IPv4
-socket with endpoint "<IP>:<PORT>" or "tcp://<IP>:<PORT>", a Unix Domain Socket
-with endpoint "unix://<pathname>", and STDIN with endpoint "fd://0" or "fd://-".
+socket with endpoint "\<IP\>:\<PORT\>" or "tcp://\<IP\>:\<PORT\>", a Unix Domain Socket
+with endpoint "unix://\<pathname\>", and STDIN with endpoint "fd://0" or "fd://-".
 
 The following example will create an `ArrowStreamDataset` that will connect to
 a local host endpoint that is serving an Arrow stream of record batches with 2
@@ -89,13 +89,13 @@ columns of dtypes=(int32, float32):
 
 ```python
 import tensorflow as tf
-from tensorflow_io.arrow import ArrowStreamDataset
+import tensorflow_io.arrow as arrow_io
 
 # The parameter `endpoints` can be a Python string or a list of strings and
 # should be in the format '<IP>:<PORT>' for an IPv4 host
 endpoints = '127.0.0.1:8999'
 
-dataset = ArrowStreamDataset(
+dataset = arrow_io.ArrowStreamDataset(
     endpoints,
     columns=(0, 1),
     output_types=(tf.int32, tf.float32),
@@ -115,7 +115,7 @@ with tf.Session() as sess:
 ```
 
 An alternate constructor can also be used to infer output types and shapes from
-a given `pyarrow.Schema`, e.g. `dataset = ArrowStreamDataset.from_schema(host, schema)`
+a given `pyarrow.Schema`, e.g. `dataset = arrow_io.ArrowStreamDataset.from_schema(host, schema)`
 
 ## Creating Batches with Arrow Datasets
 
@@ -127,7 +127,7 @@ and 'auto'. If the last elements of the Dataset do not combine to the set
 will automatically set a batch size to the number of records in the
 incoming Arrow record batches.
 This is a good option to use if the incoming Arrow record batch size can be controlled to ensure the output batch size is not too large
-and sequential Arrow record batches are sized equally.
+and each of the Arrow record batches is sized equally.
 
 Setting the `batch_size` or using `batch_mode` of 'auto' can be more efficient
 than using `tf.data.Dataset.batch()` on an Arrow Dataset. This is because the
@@ -135,4 +135,4 @@ output tensor can be sized to the desired batch size on creation, and then data
 is transferred directly from Arrow memory. Otherwise, if batching elements with
 the output of an Arrow Dataset, e.g. `ArrowDataset(...).batch(batch_size=4)`,
 then the tensor data will need to be aggregated and copied to get the final
-batched outputs.
+batched output.
diff --git a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
index 3c3ca76ad..f57bf48f7 100644
--- a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
+++ b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -377,7 +377,7 @@ def __init__(self,
     Args:
       endpoints: A `tf.string` tensor, Python list or scalar string defining the
                  input stream.
-                 `endpoints` could have the following formats:
+                 `endpoints` supports the following formats:
                  - "host:port": IPv4 address (default)
                  - "tcp://<host:port>": IPv4 address,
                  - "unix://<path>": local path as unix socket address,
@@ -422,7 +422,7 @@ def from_schema(cls,
     Args:
       endpoints: A `tf.string` tensor, Python list or scalar string defining the
                  input stream.
-                 `endpoints` could have the following formats:
+                 `endpoints` supports the following formats:
                  - "host:port": IPv4 address (default)
                  - "tcp://<host:port>": IPv4 address,
                  - "unix://<path>": local path as unix socket address,
@@ -478,7 +478,7 @@ def from_record_batches(cls,
         "drop_remainder" (discard partial batch data), "auto" (size to
         number of records in Arrow record batch)
       record_batch_iter_factory: Optional factory to create additional record
-        batch iterators after being consumed.
+        batch iterators for multiple iterations.
     """
 
     import pyarrow as pa
@@ -549,7 +549,7 @@ def from_pandas(cls,
     This constructor requires pandas and pyarrow to be installed.
 
     Args:
-      df: A Pandas DataFrame or sequence of DataFrames
+      data_frames: A Pandas DataFrame or sequence of DataFrames
       columns: Optional column indices to use, if None all are used
      preserve_index: Flag to include the DataFrame index as the last column
       batch_size: Batch size of output tensors, setting a batch size here
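Usage note (not part of the patch): the `from_pandas` constructor documented in the last hunk can be exercised as sketched below. This is a minimal sketch based only on the APIs shown in the patch above; the DataFrame contents are hypothetical, and it assumes a TF 1.x environment with tensorflow-io, pandas, and pyarrow installed, matching the README examples in the patched file.

```python
import pandas as pd
import tensorflow as tf
import tensorflow_io.arrow as arrow_io

# Hypothetical DataFrame using the (int32, float32) column dtypes from the
# README examples in this patch.
df = pd.DataFrame({'x': pd.Series([1, 2, 3, 4, 5], dtype='int32'),
                   'y': pd.Series([.1, .2, .3, .4, .5], dtype='float32')})

# `batch_size` and `preserve_index` come from the from_pandas docstring above;
# per the README's batching section, the default 'keep_remainder' behavior
# should yield two batches of 2 records and a final partial batch of 1.
dataset = arrow_io.ArrowDataset.from_pandas(
    df, preserve_index=False, batch_size=2)

# TF 1.x iteration style, as used in the README examples in the patch.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```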