Skip to content

Commit

Permalink
Merge pull request BVLC#2978 from lukeyeager/h5t_integer
Browse files Browse the repository at this point in the history
Allow H5T_INTEGER in HDF5 files
  • Loading branch information
jeffdonahue committed Sep 24, 2015
2 parents 37dc63c + ebc9963 commit 349ff65
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
14 changes: 8 additions & 6 deletions src/caffe/test/test_data/generate_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,25 @@
f['label'] = label
f['label2'] = label2

with h5py.File(script_dir + '/sample_data_2_gzip.h5', 'w') as f:
with h5py.File(script_dir + '/sample_data_uint8_gzip.h5', 'w') as f:
f.create_dataset(
'data', data=data + total_size,
compression='gzip', compression_opts=1
)
f.create_dataset(
'label', data=label,
compression='gzip', compression_opts=1
compression='gzip', compression_opts=1,
dtype='uint8',
)
f.create_dataset(
'label2', data=label2,
compression='gzip', compression_opts=1
compression='gzip', compression_opts=1,
dtype='uint8',
)

with open(script_dir + '/sample_data_list.txt', 'w') as f:
f.write(script_dir + '/sample_data.h5\n')
f.write(script_dir + '/sample_data_2_gzip.h5\n')
f.write('src/caffe/test/test_data/sample_data.h5\n')
f.write('src/caffe/test/test_data/sample_uint8_gzip.h5\n')

# Generate GradientBasedSolver solver_data.h5

Expand All @@ -76,4 +78,4 @@
f['targets'] = targets

with open(script_dir + '/solver_data_list.txt', 'w') as f:
f.write(script_dir + '/solver_data.h5\n')
f.write('src/caffe/test/test_data/solver_data.h5\n')
Binary file modified src/caffe/test/test_data/sample_data_2_gzip.h5
Binary file not shown.
29 changes: 28 additions & 1 deletion src/caffe/util/hdf5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,34 @@ void hdf5_load_nd_dataset_helper(
status = H5LTget_dataset_info(
file_id, dataset_name_, dims.data(), &class_, NULL);
CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
switch (class_) {
case H5T_FLOAT:
LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT";
break;
case H5T_INTEGER:
LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER";
break;
case H5T_TIME:
LOG(FATAL) << "Unsupported datatype class: H5T_TIME";
case H5T_STRING:
LOG(FATAL) << "Unsupported datatype class: H5T_STRING";
case H5T_BITFIELD:
LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";
case H5T_OPAQUE:
LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";
case H5T_COMPOUND:
LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";
case H5T_REFERENCE:
LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";
case H5T_ENUM:
LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";
case H5T_VLEN:
LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";
case H5T_ARRAY:
LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";
default:
LOG(FATAL) << "Datatype class unknown";
}

vector<int> blob_dims(dims.size());
for (int i = 0; i < dims.size(); ++i) {
Expand Down

0 comments on commit 349ff65

Please sign in to comment.