Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fluid benchmark support recordio reader #11121

13 changes: 8 additions & 5 deletions benchmark/fluid/fluid_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
# FIXME(wuyi): For use_reader_op, if the current
# pass is not the last, the last batch of this pass
# is also equal to args.batch_size.
num_samples += len(args.batch_size)
if args.use_reader_op:
num_samples += args.batch_size
Copy link
Contributor

@chengduoZH chengduoZH Jun 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args.batch_size is the batch size on each GPU now. So it should be num_samples += args.batch_size * args.gpus.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much! Done. Current know issue, if set --use_reader_op we must also set --no_test will fix this in next PR.

else:
num_samples += len(data)
train_losses.append(loss)
print("Pass: %d, Iter: %d, Loss: %f\n" %
(pass_id, iters, np.mean(train_losses)))
Expand Down Expand Up @@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
# NOTE: if use reader ops, the input data is not splited to multiple cards
if args.use_reader_op and iters >= args.iterations / args.gpus:
break
if args.use_fake_data or args.use_reader_op:
try:
loss, = exe.run([avg_loss.name])
Expand All @@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver":
exe.bcast_params()
num_samples += len(data)
if args.use_reader_op:
num_samples += args.batch_size
else:
num_samples += len(data)
iters += 1
if batch_id % 1 == 0:
print("Pass %d, batch %d, loss %s" %
Expand Down