Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DataFeeder #6102

Merged
merged 6 commits into from
Dec 4, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/v2/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@
import backward
import regularizer
from param_attr import ParamAttr

from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, GPUPlace

Tensor = LoDTensor
# Public names of the fluid package: everything exported by `framework` and
# `executor`, plus the sub-modules and classes listed below.
# NOTE: every entry needs its own trailing comma. Adjacent string literals are
# concatenated by Python, so a missing comma silently merges two names into one
# bogus export (e.g. 'ParamAttrDataFeeder').
__all__ = framework.__all__ + executor.__all__ + [
    'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
    'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr',
    'DataFeeder',
]


def __read_gflags_from_env__():
"""
Enable reading gflags from environment variables.

Returns:
None
"""
Expand Down
95 changes: 95 additions & 0 deletions python/paddle/v2/fluid/data_feeder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import print_function
from framework import Variable
import core
import layers
import numpy
import six.moves as six


class DataToLoDTensorConverter(object):
def __init__(self, place, lod_level, shape, batch_size_dim, dtype):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that batch_size_dim is not used here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. batch_size_dim could be removed.

self.place = place
self.lod_level = lod_level
self.shape = shape
self.batch_size_dim = batch_size_dim
if dtype == core.DataType.FP32:
self.dtype = 'float32'
elif dtype == core.DataType.INT64:
self.dtype = 'int64'
elif dtype == core.DataType.FP64:
self.dtype = 'float64'
elif dtype == core.DataType.INT32:
self.dtype = 'int32'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an `else:` branch here that raises an exception for unsupported dtypes.


self.data = []
self.lod = []

for i in six.range(lod_level):
self.lod.append([0])

def feed(self, data):
    """Add one sample to this converter.

    The work is delegated to ``_feed_impl_``, starting from the converter's
    own offset lists (``self.lod``) and its full nesting depth
    (``self.lod_level``).
    """
    offsets = self.lod
    depth = self.lod_level
    self._feed_impl_(data, offsets, depth)

def _feed_impl_(self, data, lod, lod_level):
if lod_level == 0:
self.data.append(data)
else:
cur_lod_len = len(data)
lod[-1].append(lod[-1][-1] + cur_lod_len)
for each_data in data:
self._feed_impl_(each_data, lod[:-1], lod_level - 1)

def done(self):
arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
t = core.LoDTensor()
t.set(arr, self.place)
if self.lod_level != 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.lod_level > 0

t.set_lod(self.lod)
return t


class DataFeeder(object):
def __init__(self, feed_list, place):
self.feed_dtypes = []
self.feed_names = []
self.feed_shapes = []
self.feed_lod_level = []
for each_var in feed_list:
if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable")
self.feed_dtypes.append(each_var.dtype)
self.feed_names.append(each_var.name)
shape = each_var.shape
batch_size_dim = -1
for i, s in enumerate(shape):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't batch_size_dim always the first dimension?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, batch_size_dim is not the first dimension when a static RNN is used.

if s < 0:
batch_size_dim = i
break
if batch_size_dim == -1:
raise ValueError("Variable {0} must has a batch size dimension",
each_var.name)
self.feed_lod_level.append(each_var.lod_level)
self.feed_shapes.append((batch_size_dim, shape))

self.place = place

def feed(self, iterable):
converter = []
for lod_level, shape, dtype in six.zip(
self.feed_lod_level, self.feed_shapes, self.feed_dtypes):
batch_size_dim, shape = shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line takes batch_size_dim as the first dimension.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Before this line, `shape` is a `(batch_size_dim, shape)` tuple; the line just unpacks it.

converter.append(
DataToLoDTensorConverter(
place=self.place,
lod_level=lod_level,
shape=shape,
batch_size_dim=batch_size_dim,
dtype=dtype))

for each_sample in iterable:
for each_converter, each_slot in six.zip(converter, each_sample):
each_converter.feed(each_slot)
ret_dict = {}
for each_name, each_converter in six.zip(self.feed_names, converter):
ret_dict[each_name] = each_converter.done()
return ret_dict
7 changes: 2 additions & 5 deletions python/paddle/v2/fluid/tests/book/test_fit_a_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)

exe.run(fluid.default_startup_program())
Expand All @@ -31,12 +32,8 @@
fluid.io.save_persistables(exe, "./fit_a_line.model/")
fluid.io.load_persistables(exe, "./fit_a_line.model/")
for data in train_reader():
x_data = np.array(map(lambda _: _[0], data)).astype("float32")
y_data = np.array(map(lambda _: _[1], data)).astype("float32")

avg_loss_value, = exe.run(fluid.default_main_program(),
feed={'x': x_data,
'y': y_data},
feed=feeder.feed(data),
fetch_list=[avg_cost])

if avg_loss_value[0] < 10.0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,14 @@ def conv_block(input, num_filter, groups, dropouts):

place = fluid.CPUPlace()
exe = fluid.Executor(place)

feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program())

for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
batch_size = 1
for i in y_data.shape:
batch_size = batch_size * i
y_data = y_data.reshape([batch_size, 1])

loss, acc = exe.run(fluid.default_main_program(),
feed={"pixel": img_data,
"label": y_data},
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
Expand Down
60 changes: 27 additions & 33 deletions python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,9 @@ def load_parameter(file_name, h, w):
return np.fromfile(f, dtype=np.float32).reshape(h, w)


def db_lstm():
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
**ignored):
# 8 features
word = fluid.layers.data(name='word_data', shape=[1], dtype='int64')
predicate = fluid.layers.data(name='verb_data', shape=[1], dtype='int64')
ctx_n2 = fluid.layers.data(name='ctx_n2_data', shape=[1], dtype='int64')
ctx_n1 = fluid.layers.data(name='ctx_n1_data', shape=[1], dtype='int64')
ctx_0 = fluid.layers.data(name='ctx_0_data', shape=[1], dtype='int64')
ctx_p1 = fluid.layers.data(name='ctx_p1_data', shape=[1], dtype='int64')
ctx_p2 = fluid.layers.data(name='ctx_p2_data', shape=[1], dtype='int64')
mark = fluid.layers.data(name='mark_data', shape=[1], dtype='int64')

predicate_embedding = fluid.layers.embedding(
input=predicate,
size=[pred_len, word_dim],
Expand Down Expand Up @@ -120,8 +112,25 @@ def to_lodtensor(data, place):

def main():
# define network topology
feature_out = db_lstm()
target = fluid.layers.data(name='target', shape=[1], dtype='int64')
word = fluid.layers.data(
name='word_data', shape=[1], dtype='int64', lod_level=1)
predicate = fluid.layers.data(
name='verb_data', shape=[1], dtype='int64', lod_level=1)
ctx_n2 = fluid.layers.data(
name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1)
ctx_n1 = fluid.layers.data(
name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1)
ctx_0 = fluid.layers.data(
name='ctx_0_data', shape=[1], dtype='int64', lod_level=1)
ctx_p1 = fluid.layers.data(
name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1)
ctx_p2 = fluid.layers.data(
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
mark = fluid.layers.data(
name='mark_data', shape=[1], dtype='int64', lod_level=1)
feature_out = db_lstm(**locals())
target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1)
crf_cost = fluid.layers.linear_chain_crf(
input=feature_out,
label=target,
Expand All @@ -139,6 +148,11 @@ def main():
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
],
place=place)
exe = fluid.Executor(place)

exe.run(fluid.default_startup_program())
Expand All @@ -150,28 +164,8 @@ def main():
batch_id = 0
for pass_id in xrange(PASS_NUM):
for data in train_data():
word_data = to_lodtensor(map(lambda x: x[0], data), place)
ctx_n2_data = to_lodtensor(map(lambda x: x[1], data), place)
ctx_n1_data = to_lodtensor(map(lambda x: x[2], data), place)
ctx_0_data = to_lodtensor(map(lambda x: x[3], data), place)
ctx_p1_data = to_lodtensor(map(lambda x: x[4], data), place)
ctx_p2_data = to_lodtensor(map(lambda x: x[5], data), place)
verb_data = to_lodtensor(map(lambda x: x[6], data), place)
mark_data = to_lodtensor(map(lambda x: x[7], data), place)
target = to_lodtensor(map(lambda x: x[8], data), place)

outs = exe.run(fluid.default_main_program(),
feed={
'word_data': word_data,
'ctx_n2_data': ctx_n2_data,
'ctx_n1_data': ctx_n1_data,
'ctx_0_data': ctx_0_data,
'ctx_p1_data': ctx_p1_data,
'ctx_p2_data': ctx_p2_data,
'verb_data': verb_data,
'mark_data': mark_data,
'target': target
},
feed=feeder.feed(data),
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])

Expand Down
10 changes: 2 additions & 8 deletions python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,14 @@

place = fluid.CPUPlace()
exe = fluid.Executor(place)

feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
exe.run(fluid.default_startup_program())

for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([BATCH_SIZE, 1])

loss, acc = exe.run(fluid.default_main_program(),
feed={"pixel": img_data,
"label": y_data},
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +
Expand Down
28 changes: 5 additions & 23 deletions python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,40 +48,22 @@

place = fluid.CPUPlace()
exe = fluid.Executor(place)

feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(fluid.default_startup_program())

PASS_NUM = 100
for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)

tensor_x = fluid.LoDTensor()
tensor_x.set(x_data, place)

tensor_y = fluid.LoDTensor()
tensor_y.set(y_data, place)

outs = exe.run(fluid.default_main_program(),
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost] + accuracy.metrics)
out = np.array(outs[0])
acc = np.array(outs[1])
out, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)

test_accuracy.reset(exe)
for data in test_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)

out, acc = exe.run(inference_program,
feed={'x': x_data,
'y': y_data},
feed=feeder.feed(data),
fetch_list=[avg_cost] + test_accuracy.metrics)

test_pass_acc = test_accuracy.eval(exe)
Expand Down
28 changes: 10 additions & 18 deletions python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import paddle.v2.fluid as fluid


def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32):
data = fluid.layers.data(name="words", shape=[1], dtype="int64")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
hid_dim=32):
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim])
conv_3 = fluid.nets.sequence_conv_pool(
input=emb,
Expand Down Expand Up @@ -55,34 +53,28 @@ def main():
dict_dim = len(word_dict)
class_dim = 2

data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost, accuracy, acc_out = convolution_net(
input_dim=dict_dim, class_dim=class_dim)
data, label, input_dim=dict_dim, class_dim=class_dim)

train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)

exe.run(fluid.default_startup_program())

for pass_id in xrange(PASS_NUM):
accuracy.reset(exe)
for data in train_data():
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)

label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([BATCH_SIZE, 1])

tensor_label = fluid.LoDTensor()
tensor_label.set(label, place)

cost_val, acc_val = exe.run(
fluid.default_main_program(),
feed={"words": tensor_words,
"label": tensor_label},
fetch_list=[cost, acc_out])
cost_val, acc_val = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[cost, acc_out])
pass_acc = accuracy.eval(exe)
print("cost=" + str(cost_val) + " acc=" + str(acc_val) +
" pass_acc=" + str(pass_acc))
Expand Down
Loading