From 84e390c5a16347c7369f6c92cb62526e42ce73ac Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Thu, 24 Sep 2015 12:35:35 -0700 Subject: [PATCH 1/2] Allow H5T_INTEGER in HDF5 files --- src/caffe/util/hdf5.cpp | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp index d0d05f70f8f..7730e76ab87 100644 --- a/src/caffe/util/hdf5.cpp +++ b/src/caffe/util/hdf5.cpp @@ -27,7 +27,34 @@ void hdf5_load_nd_dataset_helper( status = H5LTget_dataset_info( file_id, dataset_name_, dims.data(), &class_, NULL); CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_; - CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data"; + switch (class_) { + case H5T_FLOAT: + LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT"; + break; + case H5T_INTEGER: + LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER"; + break; + case H5T_TIME: + LOG(FATAL) << "Unsupported datatype class: H5T_TIME"; + case H5T_STRING: + LOG(FATAL) << "Unsupported datatype class: H5T_STRING"; + case H5T_BITFIELD: + LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD"; + case H5T_OPAQUE: + LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE"; + case H5T_COMPOUND: + LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND"; + case H5T_REFERENCE: + LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE"; + case H5T_ENUM: + LOG(FATAL) << "Unsupported datatype class: H5T_ENUM"; + case H5T_VLEN: + LOG(FATAL) << "Unsupported datatype class: H5T_VLEN"; + case H5T_ARRAY: + LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY"; + default: + LOG(FATAL) << "Datatype class unknown"; + } vector blob_dims(dims.size()); for (int i = 0; i < dims.size(); ++i) { From ebc9963fea7b72f397c446a10a9aeab576979566 Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Tue, 25 Aug 2015 18:58:45 -0700 Subject: [PATCH 2/2] Modify HDF5DataLayerTest to test H5T_INTEGER data --- .../test/test_data/generate_sample_data.py | 14 ++++++++------ .../test/test_data/sample_data_2_gzip.h5 | Bin 15446 -> 15446 bytes 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/caffe/test/test_data/generate_sample_data.py b/src/caffe/test/test_data/generate_sample_data.py index 3703b41823b..8349dbbc8e6 100644 --- a/src/caffe/test/test_data/generate_sample_data.py +++ b/src/caffe/test/test_data/generate_sample_data.py @@ -36,23 +36,25 @@ f['label'] = label f['label2'] = label2 -with h5py.File(script_dir + '/sample_data_2_gzip.h5', 'w') as f: +with h5py.File(script_dir + '/sample_data_uint8_gzip.h5', 'w') as f: f.create_dataset( 'data', data=data + total_size, compression='gzip', compression_opts=1 ) f.create_dataset( 'label', data=label, - compression='gzip', compression_opts=1 + compression='gzip', compression_opts=1, + dtype='uint8', ) f.create_dataset( 'label2', data=label2, - compression='gzip', compression_opts=1 + compression='gzip', compression_opts=1, + dtype='uint8', ) with open(script_dir + '/sample_data_list.txt', 'w') as f: - f.write(script_dir + '/sample_data.h5\n') - f.write(script_dir + '/sample_data_2_gzip.h5\n') + f.write('src/caffe/test/test_data/sample_data.h5\n') + f.write('src/caffe/test/test_data/sample_uint8_gzip.h5\n') # Generate GradientBasedSolver solver_data.h5 @@ -76,4 +78,4 @@ f['targets'] = targets with open(script_dir + '/solver_data_list.txt', 'w') as f: - f.write(script_dir + '/solver_data.h5\n') + f.write('src/caffe/test/test_data/solver_data.h5\n') diff --git a/src/caffe/test/test_data/sample_data_2_gzip.h5 b/src/caffe/test/test_data/sample_data_2_gzip.h5 index a138e0367be3d4b4ce4b51dcf0d7895056018883..0cb9ef92241d049b699b65f87e800f97337cae54 100644 GIT binary patch delta 225 zcmcasajjwl4+~4C%-zt<0xUly1qB!w85kG@fEYwGFmOy(l#7r6vxOKqz(ODnNCN|d z$Ha}klMPrT7=H51ueFFg#LWgGis1!kei3Wf$yr%6iSb&3kmDU2qQ`Q_!o GjsXC!dpCsu