Image classification #10738
Merged
daming-lu merged 31 commits into PaddlePaddle:develop from daming-lu:image_classification_word2vec on May 18, 2018
9dff644 Inferencer take infer_func as parameter (jacquesqiao)
d94f673 update trainer and word2vector demo (jacquesqiao)
52ac039 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (jacquesqiao)
acc94dc delete unused code (jacquesqiao)
214dc6e update test_fit_a_line (jacquesqiao)
1a05ae0 update test_recognize_digits_conv.py (jacquesqiao)
e347198 update test_recognize_digits_mlp.py (jacquesqiao)
ac0836c clean code (jacquesqiao)
1801557 fix test failure
76b7a43 fix style
683c48f style
33031d9 Merge branch 'develop' into image_classification_word2vec (daming-lu)
98ef2c8 rm notest (daming-lu)
6bfa6ff finish vgg (daming-lu)
ccd95e3 Merge remote-tracking branch 'upstream/develop' into image_classifica… (daming-lu)
04200d7 style (daming-lu)
2d99eb1 image classification done (daming-lu)
6c3ca56 Merge remote-tracking branch 'upstream/develop' into image_classifica… (daming-lu)
4108430 style
aa2b0bf the train_network returned result has to be an array (daming-lu)
bc485f7 Merge branch 'image_classification_word2vec' of https://github.com/da… (daming-lu)
6b405c1 add cmake file (daming-lu)
f863f28 Merge remote-tracking branch 'upstream/develop' into image_classifica… (daming-lu)
939b2c7 move cifar10 dataset to local so that we can read a smaller dataset (daming-lu)
0afb31e Merge remote-tracking branch 'upstream/develop' into image_classifica… (daming-lu)
8bd087f switch to smaller dataset (daming-lu)
b23b6e8 style
f59c8c9 Merge remote-tracking branch 'upstream/develop' into image_classifica…
2ef6f2f tune threshold to be small as the training sample is small (daming-lu)
47a8b25 Merge branch 'image_classification_word2vec' of https://github.com/da… (daming-lu)
a6ec94e Merge remote-tracking branch 'upstream/develop' into image_classifica… (daming-lu)
7 changes: 7 additions & 0 deletions
python/paddle/fluid/tests/book/high-level-api/image_classification/CMakeLists.txt
@@ -0,0 +1,7 @@

    file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
    string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

    # default test
    foreach(src ${TEST_OPS})
        py_test(${src} SRCS ${src}.py)
    endforeach()
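
As a rough illustration of what this CMake file does, the Python sketch below mirrors its two steps: collect every `test_*.py` in a directory, then strip `.py` to get one test target name per file. The function name `discover_test_targets` and the file `test_example.py` are hypothetical, chosen only for the demonstration. The sketch also suggests that `cifar10_small_test_set.py` would not be registered as a test, since its name does not match the `test_*.py` pattern.

```python
import fnmatch
import os
import tempfile


def discover_test_targets(source_dir):
    """Return a target name for every test_*.py directly inside source_dir."""
    names = [n for n in os.listdir(source_dir) if fnmatch.fnmatch(n, "test_*.py")]
    # Like string(REPLACE ".py" "" ...), this removes every ".py" occurrence
    # in a name, not only a trailing extension.
    return sorted(n.replace(".py", "") for n in names)


with tempfile.TemporaryDirectory() as d:
    # test_example.py is a hypothetical test file used only for this demo.
    for fname in ("test_example.py", "cifar10_small_test_set.py"):
        open(os.path.join(d, fname), "w").close()
    targets = discover_test_targets(d)
    print(targets)
```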
82 changes: 82 additions & 0 deletions
python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py
@@ -0,0 +1,82 @@

    # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """
    CIFAR dataset.

    This module will download dataset from
    https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
    paddle reader creators.

    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
    with 6000 images per class. There are 50000 training images and 10000 test
    images.

    The CIFAR-100 dataset is just like the CIFAR-10, except it has 100 classes
    containing 600 images each. There are 500 training images and 100 testing
    images per class.

    """

    import cPickle
    import itertools
    import numpy
    import paddle.v2.dataset.common
    import tarfile

    __all__ = ['train10']

    URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
    CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
    CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'


    def reader_creator(filename, sub_name, batch_size=None):
        def read_batch(batch):
            data = batch['data']
            labels = batch.get('labels', batch.get('fine_labels', None))
            assert labels is not None
            for sample, label in itertools.izip(data, labels):
                yield (sample / 255.0).astype(numpy.float32), int(label)

        def reader():
            with tarfile.open(filename, mode='r') as f:
                names = (each_item.name for each_item in f
                         if sub_name in each_item.name)

                batch_count = 0
                for name in names:
                    batch = cPickle.load(f.extractfile(name))
                    for item in read_batch(batch):
                        if isinstance(batch_size, int) and batch_count > batch_size:
                            break
                        batch_count += 1
                        yield item

        return reader


    def train10(batch_size=None):
        """
        CIFAR-10 training set creator.

        It returns a reader creator, each sample in the reader is image pixels in
        [0, 1] and label in [0, 9].

        :return: Training reader creator
        :rtype: callable
        """
        return reader_creator(
            paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
            'data_batch',
            batch_size=batch_size)
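
The `batch_size` argument above caps how many samples the reader yields so that a test does not have to iterate the full training set. The Python 3 sketch below shows one possible way to express the same cap with `itertools.islice`; it is an illustration, not the code in this PR. `limit_reader`, `count_with_gt_check`, and `fake_reader` are hypothetical names, and `fake_reader` stands in for a real CIFAR reader. The second helper mimics the `batch_count > batch_size` check from the diff, which appears to let `batch_size + 1` samples through before stopping. Note also that, as written in the diff, `break` leaves only the inner loop, so the remaining archive members still seem to be unpickled before each is skipped.

```python
import itertools


def limit_reader(reader, max_samples=None):
    """Wrap a reader creator so it yields at most max_samples items."""
    def limited():
        if max_samples is None:
            return reader()
        # islice stops pulling from the underlying reader after max_samples
        return itertools.islice(reader(), max_samples)
    return limited


def count_with_gt_check(reader, batch_size):
    """Count the items let through by a `count > batch_size` check."""
    count = 0
    for _ in reader():
        if count > batch_size:
            break
        count += 1
    return count


def fake_reader():
    # Hypothetical stand-in for a CIFAR reader: 100 (pixels, label) pairs
    for i in range(100):
        yield [0.0] * 4, i % 10


print(len(list(limit_reader(fake_reader, 10)())))  # 10
print(count_with_gt_check(fake_reader, 10))        # 11
```

For a smoke test the one-sample difference is harmless, but a check of `>=` (or `islice`) would make the cap match the argument exactly.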
If we do not add this helper method, the minimum training loop is one full epoch, which could take a long time to finish.
We can use Trainer.stop, added by #10762.
Discussed offline: since CI is just 10 minutes away from finishing, we can merge this first and switch to trainer.stop() once that PR is merged. 😬
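
To make the alternative discussed above concrete, here is a minimal sketch of stopping a training loop early from an event handler instead of shrinking the dataset. `ToyTrainer` is a hypothetical stand-in written for this example. It is not Paddle's `Trainer`, and the actual signature and semantics of the `Trainer.stop` added in #10762 are not shown on this page.

```python
class ToyTrainer:
    """Hypothetical stand-in trainer; not Paddle's Trainer API."""

    def __init__(self):
        self._stop_requested = False

    def stop(self):
        # Ask the training loop to exit after the current step
        self._stop_requested = True

    def train(self, num_epochs, steps_per_epoch, event_handler):
        steps_run = 0
        for epoch in range(num_epochs):
            for step in range(steps_per_epoch):
                steps_run += 1  # a real trainer would run one batch here
                event_handler(epoch, step)
                if self._stop_requested:
                    return steps_run
        return steps_run


trainer = ToyTrainer()


def handler(epoch, step):
    # Stop after five steps instead of finishing the whole epoch
    if step >= 4:
        trainer.stop()


steps = trainer.train(num_epochs=1, steps_per_epoch=1000, event_handler=handler)
print(steps)  # 5
```

With this pattern the test can keep the full dataset reader and still finish quickly, because the handler decides when enough steps have run.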