-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add min_pool_size, Add default value of should_shuffle #70
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from paddle.trainer.PyDataProvider2 import * | ||
|
||
|
||
# Define a py data provider | ||
@provider(input_types=[ | ||
dense_vector(28 * 28), | ||
integer_value(10) | ||
]) | ||
def process(settings, filename): # settings is not used currently. | ||
f = open(filename, 'r') # open one of training file | ||
|
||
for line in f: # read each line | ||
label, pixel = line.split(';') | ||
|
||
# get features and label | ||
pixels_str = pixel.split(' ') | ||
|
||
pixels_float = [] | ||
for each_pixel_str in pixels_str: | ||
pixels_float.append(float(each_pixel_str)) | ||
|
||
# give data to paddle. | ||
yield { "pixel": pixels_float, 'label': int(label) } | ||
|
||
f.close() # close file |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,9 +149,13 @@ void DoubleBuffer::startAsyncLoad() { | |
taskReadySem_.post(); | ||
} | ||
|
||
ClassRegistrar<DataProvider, DataConfig, bool> DataProvider::registrar_; | ||
DataProvider* DataProvider::create(const DataConfig& config, bool useGpu) { | ||
return registrar_.createByType(config.type(), config, useGpu); | ||
ClassRegistrar<DataProvider, DataConfig, ModelConfig, bool> | ||
DataProvider::registrar_; | ||
|
||
DataProvider* DataProvider::create(const DataConfig& config, | ||
const ModelConfig& modelConfig, | ||
bool useGpu) { | ||
return registrar_.createByType(config.type(), config, modelConfig, useGpu); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add ModelConfig in |
||
} | ||
|
||
REGISTER_DATA_PROVIDER(simple, SimpleDataProvider); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,15 +39,30 @@ limitations under the License. */ | |
#include "paddle/parameter/Argument.h" | ||
|
||
namespace paddle { | ||
|
||
/** | ||
* @def REGISTER_DATA_PROVIDER | ||
* @brief Macro for registering a data provider | ||
* @brief Macro for registering a data provider. The class type should contain | ||
* a consturctor with parameter (DataConfig, bool). | ||
*/ | ||
#define REGISTER_DATA_PROVIDER(__type_name, __class_name) \ | ||
static InitFunction __reg_type_##__type_name([]() { \ | ||
DataProvider::registrar_.registerClass<__class_name>(#__type_name); \ | ||
}) | ||
#define REGISTER_DATA_PROVIDER(__type_name, __class_name)\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. write more comment |
||
static InitFunction __reg_type_##__type_name([]() {\ | ||
DataProvider::registrar_.registerClass(\ | ||
#__type_name, \ | ||
[](DataConfig conf, ModelConfig, bool useGpu) -> DataProvider* { \ | ||
DataProvider* dp = new __class_name (conf, useGpu);\ | ||
return dp;\ | ||
});\ | ||
}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. REGISTER_DATA_PROVIDER for DataProvider not use ModelConfig |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add more comment |
||
/** | ||
* @def REGISTER_DATA_PROVIDER_EX | ||
* @brief Macro for registering a data provider, which contains a constructor | ||
* with parameter (DataConfig, ModelConfig, bool). | ||
*/ | ||
#define REGISTER_DATA_PROVIDER_EX(__type_name, __class_name) \ | ||
static InitFunction __reg_type_##__type_name([] { \ | ||
DataProvider::registrar_.registerClass<__class_name>(#__type_name); \ | ||
}) | ||
|
||
class DataBatch; | ||
class BufferBatch; | ||
|
@@ -285,10 +300,18 @@ class DoubleBuffer { | |
*/ | ||
class DataProvider { | ||
public: | ||
static ClassRegistrar<DataProvider, DataConfig, bool> registrar_; | ||
static ClassRegistrar<DataProvider, DataConfig, ModelConfig, bool> registrar_; | ||
static DataProvider* create(const DataConfig& config, | ||
const ModelConfig& modelConfig, | ||
bool useGpu = FLAGS_use_gpu); | ||
|
||
/** | ||
* @brief create only used for unittest. | ||
*/ | ||
inline static DataProvider* create(const DataConfig &config, bool useGpu) { | ||
return create(config, ModelConfig(), useGpu); | ||
} | ||
|
||
DataProvider(const DataConfig& config, bool useGpu) | ||
: config_(config), | ||
skipShuffle_(false), | ||
|
@@ -336,13 +359,13 @@ class DataProvider { | |
* @note return -1 to indicate unlimited number of samples. | ||
*/ | ||
virtual int64_t getSize() = 0; | ||
|
||
/** | ||
* @brief Get next batch training samples internally | ||
* @param[in] size size of training samples to get | ||
* @param[out] batch a batch of training samples | ||
* @return actual size of obtained training samples | ||
*/ | ||
|
||
virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch) = 0; | ||
|
||
protected: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, we use paddle.trainer.PyDataProvider2.provider's comments as documentation.