Skip to content

Commit

Permalink
Merge pull request BVLC#3 from gnina/two_data_sources
Browse files Browse the repository at this point in the history
Two data sources
  • Loading branch information
dkoes authored Jun 13, 2017
2 parents df66a1b + dd5f587 commit 89015da
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
15 changes: 14 additions & 1 deletion include/caffe/layers/molgrid_data_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class MolGridDataLayer : public BaseDataLayer<Dtype> {

struct examples
{
bool store_all;
bool store_actives_decoys;
bool store_pairs;
int count;

string root_folder;

vector<example> all;
Expand All @@ -247,7 +252,15 @@ class MolGridDataLayer : public BaseDataLayer<Dtype> {
int decoys_index;
bool shuffle_on_wrap; //TODO this doesn't apply to pairs for now

examples(): all_index(0), actives_index(0), decoys_index(0), shuffle_on_wrap(false) {}
examples():
store_all(true), store_actives_decoys(true), store_pairs(true), count(0),
all_index(0), actives_index(0), decoys_index(0), shuffle_on_wrap(false) {}

examples(bool all, bool actives_decoys, bool pairs, bool shuffle, string& root):
store_all(all), store_actives_decoys(actives_decoys), store_pairs(pairs), count(0),
all_index(0), actives_index(0), decoys_index(0), shuffle_on_wrap(shuffle),
root_folder(root) {}

void add(const example& ex);
void shuffle_();
void next(example& ex);
Expand Down
43 changes: 26 additions & 17 deletions src/caffe/layers/molgrid_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,22 @@ MolGridDataLayer<Dtype>::example::example(string line, bool hasaffinity, bool ha
// Registers one example with this collection.
//
// The example is copied only into the containers enabled by the store_*
// flags:
//   - store_all            -> appended to `all`
//   - store_actives_decoys -> appended to `actives` when ex.label is truthy,
//                             otherwise to `decoys`
//   - store_pairs          -> forwarded to `pairs.add`
//
// `count` is incremented on every call, whichever containers are enabled,
// so it reflects the number of examples added even when `all` is left empty.
template <typename Dtype>
void MolGridDataLayer<Dtype>::examples::add(const example& ex)
{
  if (store_all)
  {
    all.push_back(ex);
  }
  if (store_actives_decoys)
  {
    // Split on the label: truthy labels are actives, the rest are decoys.
    if (ex.label)
      actives.push_back(ex);
    else
      decoys.push_back(ex);
  }
  if (store_pairs)
  {
    pairs.add(ex);
  }
  // Count every added example independently of the storage flags.
  count++;
}

template <typename Dtype>
Expand Down Expand Up @@ -292,6 +300,8 @@ void MolGridDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
//keep track of atoms and transformations for each example in batch
batch_transform.resize(batch_size);

CHECK_LE(inmem + paired + balanced, 1) << "Only one of inmemory, paired, and balanced can be set";

if(!inmem)
{
const string& source = this->layer_param_.molgrid_data_param().source();
Expand All @@ -305,8 +315,8 @@ void MolGridDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
std::ifstream infile(source.c_str());
CHECK((bool)infile) << "Could not open " << source;

data.root_folder = root_folder;
data.shuffle_on_wrap = shuffle;
bool all = !(balanced || paired);
data = examples(all, balanced, paired, shuffle, root_folder);

string line;
while (getline(infile, line))
Expand All @@ -324,8 +334,7 @@ void MolGridDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
std::ifstream infile(source2.c_str());
CHECK((bool)infile) << "Could not open " << source2;

data2.root_folder = root_folder2;
data2.shuffle_on_wrap = shuffle;
data2 = examples(all, balanced, paired, shuffle, root_folder2);

while (getline(infile, line))
{
Expand All @@ -351,7 +360,7 @@ void MolGridDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
data2.shuffle_();
}

LOG(INFO) << "Total examples: " << data.all.size() + data2.all.size();
LOG(INFO) << "Total examples: " << data.count + data2.count;

// Check if we would need to randomly skip a few data points
if (this->layer_param_.molgrid_data_param().rand_skip())
Expand All @@ -360,17 +369,17 @@ void MolGridDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
this->layer_param_.molgrid_data_param().rand_skip();

LOG(INFO) << "Skipping first " << skip << " data points.";
CHECK_GT(data.all.size(), skip) << "Not enough points to skip in " << source;
CHECK_GT(data.count, skip) << "Not enough points to skip in " << source;

data.all_index = skip;
data.all_index = skip % data.all.size();
data.actives_index = skip % data.actives.size();
data.decoys_index = skip % data.decoys.size();

if (two_data_sources)
{
CHECK_GT(data2.all.size(), skip) << "Not enough points to skip in " << source2;
CHECK_GT(data2.count, skip) << "Not enough points to skip in " << source2;

data2.all_index = skip;
data2.all_index = skip % data2.all.size();
data2.actives_index = skip % data2.actives.size();
data2.decoys_index = skip % data2.decoys.size();
}
Expand Down

0 comments on commit 89015da

Please sign in to comment.