Skip to content

Commit

Permalink
Support universal inputs (PaddlePaddle#56)
Browse files Browse the repository at this point in the history
* add support for universal input
  • Loading branch information
lilong12 authored Jul 1, 2020
1 parent 632a8f3 commit 59f79be
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 63 deletions.
140 changes: 82 additions & 58 deletions plsc/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

import errno
import json
import logging
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
import logging

import numpy as np
import paddle
Expand All @@ -43,6 +42,7 @@
from .utils.learning_rate import lr_warmup
from .utils.parameter_converter import ParameterConverter
from .utils.verification import evaluate
from .utils.input_field import InputField

log_handler = logging.StreamHandler()
log_format = logging.Formatter(
Expand Down Expand Up @@ -116,7 +116,6 @@ def __init__(self):
self.val_targets = self.config.val_targets
self.dataset_dir = self.config.dataset_dir
self.num_classes = self.config.num_classes
self.image_shape = self.config.image_shape
self.loss_type = self.config.loss_type
self.margin = self.config.margin
self.scale = self.config.scale
Expand All @@ -142,6 +141,15 @@ def __init__(self):
self.lr_decay_factor = 0.1
self.log_period = 200

self.input_info = [{'name': 'image',
'shape': [-1, 3, 224, 224],
'dtype': 'float32'},
{'name': 'label',
'shape':[-1, 1],
'dtype': 'int64'}
]
self.input_field = None

logger.info('=' * 30)
logger.info("Default configuration:")
for key in self.config:
Expand All @@ -152,6 +160,31 @@ def __init__(self):
logger.info('default log period: {}'.format(self.log_period))
logger.info('=' * 30)

def set_input_info(self, input):
"""
Set the information of inputs which is a list or tuple. Each element
is a dict which contains the info of a input, including name, dtype
and shape.
"""
if not (isinstance(input, list) or isinstance(input, tuple)):
raise ValueError("The type of 'input' must be list or tuple.")

has_label = False
for element in input:
assert isinstance(element, dict), (
"The type of elements for input must be dict")
assert 'name' in element.keys(), (
"Every element has to contain the key 'name'")
assert 'shape' in element.keys(), (
"Every element has to contain the key 'shape'")
assert 'dtype' in element.keys(), (
"Every element has to contain the key 'dtype'")
if element['name'] == 'label':
has_label = True
assert has_label, "The input must contain a field named 'label'"

self.input_info = input

def set_val_targets(self, targets):
"""
Set the names of validation datasets, separated by comma.
Expand Down Expand Up @@ -314,12 +347,6 @@ def set_loss_type(self, loss_type):
self.loss_type = loss_type
logger.info("Set loss_type to {}.".format(loss_type))

def set_image_shape(self, shape):
if not isinstance(shape, (list, tuple)):
raise ValueError("Shape must be of type list or tuple")
self.image_shape = shape
logger.info("Set image_shape to {}.".format(shape))

def set_optimizer(self, optimizer):
if not isinstance(optimizer, Optimizer):
raise ValueError("Optimizer must be of type Optimizer")
Expand Down Expand Up @@ -404,7 +431,6 @@ def build_program(self,
trainer_id = self.trainer_id
num_trainers = self.num_trainers

image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
Expand All @@ -413,15 +439,11 @@ def build_program(self,
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')

emb, loss, prob = model.get_output(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()
self.input_field = input_field

emb, loss, prob = model.get_output(input=input_field,
num_ranks=num_trainers,
rank_id=trainer_id,
is_train=is_train,
Expand Down Expand Up @@ -449,7 +471,7 @@ def build_program(self,
num_or_sections=num_trainers)
prob = fluid.layers.concat(prob_list, axis=1)
label_all = fluid.layers.collective._c_allgather(
label,
input_field.label,
nranks=num_trainers,
use_calc_stream=True)
acc1 = fluid.layers.accuracy(input=prob,
Expand All @@ -461,10 +483,10 @@ def build_program(self,
else:
if self.calc_train_acc:
acc1 = fluid.layers.accuracy(input=prob,
label=label,
label=input_field.label,
k=1)
acc5 = fluid.layers.accuracy(input=prob,
label=label,
label=input_field.label,
k=5)

optimizer = None
Expand All @@ -489,7 +511,7 @@ def build_program(self,
def get_files_from_hdfs(self):
assert self.fs_checkpoint_dir, \
logger.error("Please set the fs_checkpoint_dir paramerters for "
"set_hdfs_info to get models from hdfs.")
"set_llllllhdfs_info to get models from hdfs.")
self.fs_checkpoint_dir = os.path.join(self.fs_checkpoint_dir, '*')
cmd = "hadoop fs -D fs.default.name="
cmd += self.fs_name + " "
Expand Down Expand Up @@ -631,15 +653,10 @@ def convert_for_prediction(self):
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')

emb = model.build_network(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()

emb = model.build_network(input=input_field,
is_train=False)

gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
Expand All @@ -658,8 +675,12 @@ def convert_for_prediction(self):
logger.info("model_save_dir for inference model ({}) exists, "
"we will overwrite it.".format(self.model_save_dir))
shutil.rmtree(self.model_save_dir)
feed_var_names = []
for name in input_field.feed_list_str:
if name == "label": continue
feed_var_names.append(name)
fluid.io.save_inference_model(self.model_save_dir,
feeded_var_names=[image.name],
feeded_var_names=feed_var_names,
target_vars=[emb],
executor=exe,
main_program=main_program)
Expand All @@ -678,7 +699,6 @@ def _get_info(self, key):

def predict(self):
model_name = self.model_name
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
Expand All @@ -687,15 +707,10 @@ def predict(self):
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')

emb = model.build_network(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()

emb = model.build_network(input=input_field,
is_train=False)

gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
Expand All @@ -709,20 +724,20 @@ def predict(self):
load_for_train=False)

if self.predict_reader is None:
predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
self.num_classes),
batch_size=self.train_batch_size)
predict_reader = reader.arc_train(self.dataset_dir,
self.num_classes)
else:
predict_reader = self.predict_reader

feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
program=main_program)
input_field.loader.set_sample_generator(
predict_reader,
batch_size=self.train_batch_size,
places=place)

fetch_list = [emb.name]
for data in predict_reader():
for data in input_field.loader:
emb = exe.run(main_program,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
print("emb: ", emb)
Expand All @@ -741,6 +756,14 @@ def _run_test(self,
for j in range(len(data_list)):
data = data_list[j]
embeddings = None
# For multi-card test, the dataset can be partitioned into two
# part. For the first part, the total number of samples is
# divisiable by the number of cards. And then, these samples
# are split on different cards and tested parallely. For the
# second part, these samples are tested on all cards but only
# the result of the first card is used.

# The number of steps for parallel test.
parallel_test_steps = data.shape[0] // real_test_batch_size
for idx in range(parallel_test_steps):
start = idx * real_test_batch_size
Expand Down Expand Up @@ -876,7 +899,7 @@ def test(self):
load_for_train=False)

feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
feed_list=self.input_field.feed_list_str,
program=test_program)
fetch_list = [emb_name]

Expand Down Expand Up @@ -940,9 +963,10 @@ def train(self):
else:
train_reader = self.train_reader

feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
program=origin_prog)
self.input_field.loader.set_sample_generator(
train_reader,
batch_size=self.train_batch_size,
places=place)

if self.calc_train_acc:
fetch_list = [loss.name, global_lr.name,
Expand All @@ -958,19 +982,19 @@ def train(self):
self.train_pass_id = pass_id
train_info = [[], [], [], []]
local_train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(self.input_field.loader):
nsamples += global_batch_size
t1 = time.time()
acc1 = None
acc5 = None
if self.calc_train_acc:
loss, lr, acc1, acc5 = exe.run(train_prog,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
else:
loss, lr = exe.run(train_prog,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
t2 = time.time()
Expand Down
6 changes: 3 additions & 3 deletions plsc/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class BaseModel(object):
def __init__(self):
super(BaseModel, self).__init__()

def build_network(self, input, label, is_train=True):
def build_network(self, input, is_train=True):
"""
Construct the custom model, and we will add the distributed fc layer
at the end of your model automatically.
Expand All @@ -43,7 +43,6 @@ def build_network(self, input, label, is_train=True):

def get_output(self,
input,
label,
num_classes,
num_ranks=1,
rank_id=0,
Expand Down Expand Up @@ -76,7 +75,8 @@ def get_output(self,
"Supported loss types: {}, but given: {}".format(
supported_loss_types, loss_type)

emb = self.build_network(input, label, is_train)
emb = self.build_network(input, is_train)
label = input.label
prob = None
loss = None
if loss_type == "softmax":
Expand Down
3 changes: 1 addition & 2 deletions plsc/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, layers=50, emb_dim=512):

def build_network(self,
input,
label,
is_train=True):
layers = self.layers
supported_layers = [50, 101, 152]
Expand All @@ -44,7 +43,7 @@ def build_network(self,
num_filters = [64, 128, 256, 512]

conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=1,
input=input.image, num_filters=64, filter_size=3, stride=1,
pad=1, act='prelu', is_train=is_train)

for block in range(len(depth)):
Expand Down
Loading

0 comments on commit 59f79be

Please sign in to comment.