Fixes to Arrow pydoc and README (#405)
BryanCutler authored Aug 1, 2019
1 parent 8acf617 commit 7b05cd4
Showing 2 changed files with 16 additions and 16 deletions.
24 changes: 12 additions & 12 deletions tensorflow_io/arrow/README.md
@@ -15,10 +15,10 @@ Dataset. Example usage:

```python
import tensorflow as tf
-from tensorflow_io.arrow import ArrowDataset
+import tensorflow_io.arrow as arrow_io

# Assume `df` is an existing Pandas DataFrame
-dataset = ArrowDataset.from_pandas(df)
+dataset = arrow_io.ArrowDataset.from_pandas(df)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -49,11 +49,11 @@ write_feather(df, '/path/to/a.feather')

```python
import tensorflow as tf
-from tensorflow_io.arrow import ArrowFeatherDataset
+import tensorflow_io.arrow as arrow_io

# Each Feather file must have the same column types, here we use the above
# DataFrame which has 2 columns with dtypes=(int32, float32)
-dataset = ArrowFeatherDataset(
+dataset = arrow_io.ArrowFeatherDataset(
['/path/to/a.feather', '/path/to/b.feather'],
columns=(0, 1),
output_types=(tf.int32, tf.float32),
@@ -72,30 +72,30 @@ with tf.Session() as sess:
```

An alternate constructor can also be used to infer output types and shapes from
-a given `pyarrow.Schema`, e.g. `dataset = ArrowFeatherDataset.from_schema(filenames, schema)`
+a given `pyarrow.Schema`, e.g. `dataset = arrow_io.ArrowFeatherDataset.from_schema(filenames, schema)`
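
For illustration, a minimal sketch of the `from_schema` variant; the column names and paths below are hypothetical, chosen to match the 2-column (int32, float32) example DataFrame above:

```python
import pyarrow as pa
import tensorflow_io.arrow as arrow_io

# Hypothetical schema for the 2-column (int32, float32) example DataFrame
schema = pa.schema([('label', pa.int32()), ('value', pa.float32())])

# Output types and shapes are inferred from the schema
dataset = arrow_io.ArrowFeatherDataset.from_schema(
    ['/path/to/a.feather', '/path/to/b.feather'], schema)
```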

## From a Stream of Arrow Record Batches

The `ArrowStreamDataset` provides a Dataset that will connect to one or more
endpoints that are serving Arrow record batches in the Arrow stream
format. See [here](https://arrow.apache.org/docs/python/ipc.html#writing-and-reading-streams)
for more on the stream format. Currently supported endpoints are a POSIX IPv4
-socket with endpoint "<IP>:<PORT>" or "tcp://<IP>:<PORT>", a Unix Domain Socket
-with endpoint "unix://<pathname>", and STDIN with endpoint "fd://0" or "fd://-".
+socket with endpoint "\<IP\>:\<PORT\>" or "tcp://\<IP\>:\<PORT\>", a Unix Domain Socket
+with endpoint "unix://\<pathname\>", and STDIN with endpoint "fd://0" or "fd://-".

The following example will create an `ArrowStreamDataset` that will connect to
a local host endpoint that is serving an Arrow stream of record batches with 2
columns of dtypes=(int32, float32):

```python
import tensorflow as tf
-from tensorflow_io.arrow import ArrowStreamDataset
+import tensorflow_io.arrow as arrow_io

# The parameter `endpoints` can be a Python string or a list of strings and
# should be in the format '<HOSTNAME>:<PORT>' for an IPv4 host
endpoints = '127.0.0.1:8999'

-dataset = ArrowStreamDataset(
+dataset = arrow_io.ArrowStreamDataset(
endpoints,
columns=(0, 1),
output_types=(tf.int32, tf.float32),
@@ -115,7 +115,7 @@ with tf.Session() as sess:
```

An alternate constructor can also be used to infer output types and shapes from
-a given `pyarrow.Schema`, e.g. `dataset = ArrowStreamDataset.from_schema(host, schema)`
+a given `pyarrow.Schema`, e.g. `dataset = arrow_io.ArrowStreamDataset.from_schema(host, schema)`
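
A similar sketch for the stream case, assuming the endpoint serves record batches that match the hypothetical schema below:

```python
import pyarrow as pa
import tensorflow_io.arrow as arrow_io

# Hypothetical schema describing the record batches served at the endpoint
schema = pa.schema([('label', pa.int32()), ('value', pa.float32())])

# Output types and shapes are inferred from the schema
dataset = arrow_io.ArrowStreamDataset.from_schema('127.0.0.1:8999', schema)
```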

## Creating Batches with Arrow Datasets

Expand All @@ -127,12 +127,12 @@ and 'auto'. If the last elements of the Dataset do not combine to the set
will automatically set a batch size to the number of records in the incoming
Arrow record batches. This is a good option to use if the incoming Arrow record
batch size can be controlled to ensure the output batch size is not too large
-and sequential Arrow record batches are sized equally.
+and each of the Arrow record batches are sized equally.

Setting the `batch_size` or using `batch_mode` of 'auto' can be more efficient
than using `tf.data.Dataset.batch()` on an Arrow Dataset. This is because the
output tensor can be sized to the desired batch size on creation, and then data
is transferred directly from Arrow memory. Otherwise, if batching elements with
the output of an Arrow Dataset, e.g. `ArrowDataset(...).batch(batch_size=4)`,
then the tensor data will need to be aggregated and copied to get the final
-batched outputs.
+batched output.
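
As a rough illustration of the difference, a sketch assuming `from_pandas` accepts the `batch_size` keyword described in the pydoc below (the DataFrame contents are made up):

```python
import pandas as pd
import tensorflow_io.arrow as arrow_io

df = pd.DataFrame({'label': [1, 2, 3, 4, 5]}, dtype='int32')

# Batching inside the Dataset: output tensors can be created at the
# requested size and filled directly from Arrow memory
dataset = arrow_io.ArrowDataset.from_pandas(df, batch_size=4)

# Batching after the fact: elements are produced one at a time and then
# copied and aggregated into the final batched output
dataset = arrow_io.ArrowDataset.from_pandas(df).batch(4)
```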
8 changes: 4 additions & 4 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -377,7 +377,7 @@ def __init__(self,
Args:
endpoints: A `tf.string` tensor, Python list or scalar string defining the
input stream.
-`endpoints` could have the following formats:
+`endpoints` supports the following formats:
- "host:port": IPv4 address (default)
- "tcp://<host:port>": IPv4 address,
- "unix://<path>": local path as unix socket address,
Expand Down Expand Up @@ -422,7 +422,7 @@ def from_schema(cls,
Args:
endpoints: A `tf.string` tensor, Python list or scalar string defining the
input stream.
-`endpoints` could have the following formats:
+`endpoints` supports the following formats:
- "host:port": IPv4 address (default)
- "tcp://<host:port>": IPv4 address,
- "unix://<path>": local path as unix socket address,
@@ -478,7 +478,7 @@ def from_record_batches(cls,
"drop_remainder" (discard partial batch data),
"auto" (size to number of records in Arrow record batch)
record_batch_iter_factory: Optional factory to create additional record
-batch iterators after being consumed.
+batch iterators for multiple iterations.
"""
import pyarrow as pa

@@ -549,7 +549,7 @@ def from_pandas(cls,
This constructor requires pandas and pyarrow to be installed.
Args:
-df: A Pandas DataFrame or sequence of DataFrames
+data_frames: A Pandas DataFrame or sequence of DataFrames
columns: Optional column indices to use, if None all are used
preserve_index: Flag to include the DataFrame index as the last column
batch_size: Batch size of output tensors, setting a batch size here
