From 98632f46e040217cad029e3421163c107f445e1e Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 30 Nov 2017 23:56:01 +0800 Subject: [PATCH 1/2] Add the Inception-ResNet-v2 model --- image_classification/README.md | 21 +- image_classification/inception_resnet_v2.py | 326 ++++++++++++++++++++ image_classification/infer.py | 23 +- image_classification/train.py | 15 +- 4 files changed, 370 insertions(+), 15 deletions(-) create mode 100644 image_classification/inception_resnet_v2.py diff --git a/image_classification/README.md b/image_classification/README.md index 843d683c00..6e23349200 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1,7 +1,7 @@ 图像分类 ======================= -这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet和Inception-ResNet-v2模型进行图像分类。图像分类问题的描述和这五种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 ## 训练模型 @@ -11,6 +11,8 @@ ```python import gzip +import argparse + import paddle.v2.dataset.flowers as flowers import paddle.v2 as paddle import reader @@ -18,6 +20,7 @@ import vgg import resnet import alexnet import googlenet +import inception_resnet_v2 # PaddlePaddle init @@ -29,7 +32,7 @@ paddle.init(use_gpu=False, trainer_count=1) 设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`。 ```python -DATA_DIM = 3 * 224 * 224 +DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. CLASS_DIM = 102 BATCH_SIZE = 128 @@ -41,7 +44,7 @@ lbl = paddle.layer.data( ### 获得所用模型 -这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 +这里可以选择使用AlexNet、VGG、GoogLeNet、ResNet和Inception-ResNet-v2模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 1. 使用AlexNet模型 @@ -86,6 +89,16 @@ ResNet模型可以通过下面的代码获取: out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) ``` +5. 使用Inception-ResNet-v2模型 + +提供的Inception-ResNet-v2模型支持`3 * 331 * 331`和`3 * 299 * 299`两种大小的输入,同时可以自行设置dropout概率,可以通过如下的代码使用: + +```python +out = inception_resnet_v2.inception_resnet_v2(image, class_dim=CLASS_DIM, dropout_rate=0.5, size=DATA_DIM) +``` + +注意,由于和其他几种模型输入大小不同,若配合提供的`reader.py`使用Inception-ResNet-v2时请先将`reader.py`中`paddle.image.simple_transform`中的参数为修改为相应大小。 + ### 定义损失函数 ```python @@ -173,7 +186,7 @@ def event_handler(event): ### 定义训练方法 -对于AlexNet、VGG和ResNet,可以按下面的代码定义训练方法: +对于AlexNet、VGG、ResNet和Inception-ResNet-v2,可以按下面的代码定义训练方法: ```python # Create trainer diff --git a/image_classification/inception_resnet_v2.py b/image_classification/inception_resnet_v2.py new file mode 100644 index 0000000000..6dd08c5e7a --- /dev/null +++ b/image_classification/inception_resnet_v2.py @@ -0,0 +1,326 @@ +import paddle.v2 as paddle + + +def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding=0, + active_type=paddle.activation.Relu(), + ch_in=None): + tmp = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=ch_in, + num_filters=ch_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=tmp, epsilon=0.001, act=active_type) + + +def sequential_block(input, *layers): + for layer in layers: + layer_func, layer_conf = layer + input = layer_func(input, **layer_conf) + return input + + +def mixed_5b_block(input): + branch0 = conv_bn_layer( + input, ch_in=192, ch_out=96, filter_size=1, stride=1) + branch1 = sequential_block(input, (conv_bn_layer, { + "ch_in": 192, + "ch_out": 48, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 48, + "ch_out": 64, + "filter_size": 5, + "stride": 1, + "padding": 2 + })) + branch2 = sequential_block(input, (conv_bn_layer, { + "ch_in": 192, + "ch_out": 64, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 64, + "ch_out": 96, + "filter_size": 3, + "stride": 1, + "padding": 1 + }), (conv_bn_layer, { + "ch_in": 96, + "ch_out": 96, + "filter_size": 3, + "stride": 1, + "padding": 1 + })) + branch3 = sequential_block( + input, + (paddle.layer.img_pool, { + "pool_size": 3, + "stride": 1, + "padding": 1, + "pool_type": paddle.pooling.Avg(), + "exclude_mode": False + }), + (conv_bn_layer, { + "ch_in": 192, + "ch_out": 64, + "filter_size": 1, + "stride": 1 + }), ) + out = paddle.layer.concat(input=[branch0, branch1, branch2, branch3]) + return out + + +def block35(input, scale=1.0): + branch0 = conv_bn_layer( + input, ch_in=320, ch_out=32, filter_size=1, stride=1) + branch1 = sequential_block(input, (conv_bn_layer, { + "ch_in": 320, + "ch_out": 32, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 32, + "ch_out": 32, + "filter_size": 3, + "stride": 1, + "padding": 1 + })) + branch2 = sequential_block(input, (conv_bn_layer, { + "ch_in": 320, + "ch_out": 32, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 32, + "ch_out": 48, + "filter_size": 3, + "stride": 1, + "padding": 1 + }), (conv_bn_layer, { + "ch_in": 48, + "ch_out": 64, + "filter_size": 3, + "stride": 1, + "padding": 1 + })) + out = paddle.layer.concat(input=[branch0, branch1, branch2]) + out = paddle.layer.img_conv( + input=out, + filter_size=1, + num_channels=128, + num_filters=320, + stride=1, + padding=0, + act=paddle.activation.Linear(), + bias_attr=None) + out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0) + out = paddle.layer.addto(input=[input, out], act=paddle.activation.Relu()) + return out + + +def mixed_6a_block(input): + branch0 = conv_bn_layer( + input, ch_in=320, ch_out=384, filter_size=3, stride=2) + branch1 = sequential_block(input, (conv_bn_layer, { + "ch_in": 320, + "ch_out": 256, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 256, + "ch_out": 256, + "filter_size": 3, + "stride": 1, + "padding": 1 + }), (conv_bn_layer, { + "ch_in": 256, + "ch_out": 384, + "filter_size": 3, + "stride": 2 + })) + branch2 = paddle.layer.img_pool( + input, + num_channels=320, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + out = paddle.layer.concat(input=[branch0, branch1, branch2]) + return out + + +def block17(input, scale=1.0): + branch0 = conv_bn_layer( + input, ch_in=1088, ch_out=192, filter_size=1, stride=1) + branch1 = sequential_block(input, (conv_bn_layer, { + "ch_in": 1088, + "ch_out": 128, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 128, + "ch_out": 160, + "filter_size": [7, 1], + "stride": 1, + "padding": [3, 0] + }), (conv_bn_layer, { + "ch_in": 160, + "ch_out": 192, + "filter_size": [1, 7], + "stride": 1, + "padding": [0, 3] + })) + out = paddle.layer.concat(input=[branch0, branch1]) + out = paddle.layer.img_conv( + input=out, + filter_size=1, + num_channels=384, + num_filters=1088, + stride=1, + padding=0, + act=paddle.activation.Linear(), + bias_attr=None) + out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0) + out = paddle.layer.addto(input=[input, out], act=paddle.activation.Relu()) + return out + + +def mixed_7a_block(input): + branch0 = sequential_block( + input, + (conv_bn_layer, { + "ch_in": 1088, + "ch_out": 256, + "filter_size": 1, + "stride": 1 + }), + (conv_bn_layer, { + "ch_in": 256, + "ch_out": 384, + "filter_size": 3, + "stride": 2 + }), ) + branch1 = sequential_block( + input, + (conv_bn_layer, { + "ch_in": 1088, + "ch_out": 256, + "filter_size": 1, + "stride": 1 + }), + (conv_bn_layer, { + "ch_in": 256, + "ch_out": 288, + "filter_size": 3, + "stride": 2 + }), ) + branch2 = sequential_block(input, (conv_bn_layer, { + "ch_in": 1088, + "ch_out": 256, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 256, + "ch_out": 288, + "filter_size": 3, + "stride": 1, + "padding": 1 + }), (conv_bn_layer, { + "ch_in": 288, + "ch_out": 320, + "filter_size": 3, + "stride": 2 + })) + branch3 = paddle.layer.img_pool( + input, + num_channels=1088, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + out = paddle.layer.concat(input=[branch0, branch1, branch2, branch3]) + return out + + +def block8(input, scale=1.0, no_relu=False): + branch0 = conv_bn_layer( + input, ch_in=2080, ch_out=192, filter_size=1, stride=1) + branch1 = sequential_block(input, (conv_bn_layer, { + "ch_in": 2080, + "ch_out": 192, + "filter_size": 1, + "stride": 1 + }), (conv_bn_layer, { + "ch_in": 192, + "ch_out": 224, + "filter_size": [3, 1], + "stride": 1, + "padding": [1, 0] + }), (conv_bn_layer, { + "ch_in": 224, + "ch_out": 256, + "filter_size": [1, 3], + "stride": 1, + "padding": [0, 1] + })) + out = paddle.layer.concat(input=[branch0, branch1]) + out = paddle.layer.img_conv( + input=out, + filter_size=1, + num_channels=448, + num_filters=2080, + stride=1, + padding=0, + act=paddle.activation.Linear(), + bias_attr=None) + out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0) + out = paddle.layer.addto( + input=[input, out], + act=paddle.activation.Linear() if no_relu else paddle.activation.Relu()) + return out + + +def inception_resnet_v2(input, + class_dim, + dropout_rate=0.5, + data_dim=3 * 331 * 331): + conv2d_1a = conv_bn_layer( + input, ch_in=3, ch_out=32, filter_size=3, stride=2) + conv2d_2a = conv_bn_layer( + conv2d_1a, ch_in=32, ch_out=32, filter_size=3, stride=1) + conv2d_2b = conv_bn_layer( + conv2d_2a, ch_in=32, ch_out=64, filter_size=3, stride=1, padding=1) + maxpool_3a = paddle.layer.img_pool( + input=conv2d_2b, pool_size=3, stride=2, pool_type=paddle.pooling.Max()) + conv2d_3b = conv_bn_layer( + maxpool_3a, ch_in=64, ch_out=80, filter_size=1, stride=1) + conv2d_4a = conv_bn_layer( + conv2d_3b, ch_in=80, ch_out=192, filter_size=3, stride=1) + maxpool_5a = paddle.layer.img_pool( + input=conv2d_4a, pool_size=3, stride=2, pool_type=paddle.pooling.Max()) + mixed_5b = mixed_5b_block(maxpool_5a) + repeat = sequential_block(mixed_5b, *([(block35, {"scale": 0.17})] * 10)) + mixed_6a = mixed_6a_block(repeat) + repeat1 = sequential_block(mixed_6a, *([(block17, {"scale": 0.10})] * 20)) + mixed_7a = mixed_7a_block(repeat1) + repeat2 = sequential_block(mixed_7a, *([(block8, {"scale": 0.20})] * 9)) + block_8 = block8(repeat2, no_relu=True) + conv2d_7b = conv_bn_layer( + block_8, ch_in=2080, ch_out=1536, filter_size=1, stride=1) + avgpool_1a = paddle.layer.img_pool( + input=conv2d_7b, + pool_size=8 if data_dim == 3 * 299 * 299 else 9, + stride=1, + pool_type=paddle.pooling.Avg(), + exclude_mode=False) + drop_out = paddle.layer.dropout(input=avgpool_1a, dropout_rate=dropout_rate) + out = paddle.layer.fc( + input=drop_out, size=class_dim, act=paddle.activation.Softmax()) + return out diff --git a/image_classification/infer.py b/image_classification/infer.py index 659c4f2a8e..ed68d3c8a4 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -1,18 +1,18 @@ +import os import gzip +import argparse +import numpy as np +from PIL import Image + import paddle.v2 as paddle import reader import vgg import resnet import alexnet import googlenet -import argparse -import os -from PIL import Image -import numpy as np +import inception_resnet_v2 -WIDTH = 224 -HEIGHT = 224 -DATA_DIM = 3 * WIDTH * HEIGHT +DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. CLASS_DIM = 102 @@ -26,7 +26,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'inception-resnet-v2' + ]) parser.add_argument( 'params_path', help='The file which stores the parameters') args = parser.parse_args() @@ -49,6 +52,10 @@ def main(): out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM) + elif args.model == 'inception-resnet-v2': + assert DATA_DIM == 3 * 331 * 331 or DATA_DIM == 3 * 299 * 299 + out = inception_resnet_v2.inception_resnet_v2( + image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM) # load parameters with gzip.open(args.params_path, 'r') as f: diff --git a/image_classification/train.py b/image_classification/train.py index 12a582db3a..4aeb33019f 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -1,4 +1,6 @@ import gzip +import argparse + import paddle.v2.dataset.flowers as flowers import paddle.v2 as paddle import reader @@ -6,9 +8,9 @@ import resnet import alexnet import googlenet -import argparse +import inception_resnet_v2 -DATA_DIM = 3 * 224 * 224 +DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. CLASS_DIM = 102 BATCH_SIZE = 128 @@ -19,7 +21,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'inception-resnet-v2' + ]) args = parser.parse_args() # PaddlePaddle init @@ -52,6 +57,10 @@ def main(): input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] + elif args.model == 'inception-resnet-v2': + assert DATA_DIM == 3 * 331 * 331 or DATA_DIM == 3 * 299 * 299 + out = inception_resnet_v2.inception_resnet_v2( + image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM) cost = paddle.layer.classification_cost(input=out, label=lbl) From 2076053b926ea41a6f4df8d1fdd4e8a9d86136cf Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 20 Dec 2017 17:54:08 +0800 Subject: [PATCH 2/2] Refine the inception-resnet-v2 related --- README.cn.md | 10 ++++++---- README.md | 10 ++++++---- image_classification/README.md | 6 ++++-- image_classification/inception_resnet_v2.py | 2 ++ 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/README.cn.md b/README.cn.md index 9491690e3d..4a80f2632c 100644 --- a/README.cn.md +++ b/README.cn.md @@ -98,12 +98,14 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式 图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,也是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。 -在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet和ResNet模型。同时提供了一个够将Caffe训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 +在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet、ResNet和Inception-Resnet-V2模型。同时提供了能够将Caffe或TensorFlow训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 - 11.1 [将Caffe模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) -- 11.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) -- 11.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) -- 11.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.2 [将TensorFlow模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/tf2paddle) +- 11.3 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.4 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.6 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/image_classification) ## 12. 目标检测 diff --git a/README.md b/README.md index 8b938a30dc..473ee6167e 100644 --- a/README.md +++ b/README.md @@ -72,11 +72,13 @@ As an example for sequence-to-sequence learning, we take the machine translation ## 9. Image classification -For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet and ResNet models in PaddlePaddle. It also provides a model conversion tool that converts Caffe trained model files into PaddlePaddle model files. +For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet, ResNet and Inception-Resnet-V2 models in PaddlePaddle. It also provides model conversion tools that convert Caffe or TensorFlow trained model files into PaddlePaddle model files. - 9.1 [convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) -- 9.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) -- 9.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) -- 9.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.2 [convert TensorFlow model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/tf2paddle) +- 9.3 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.4 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.6 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/image_classification) This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE). diff --git a/image_classification/README.md b/image_classification/README.md index 6e23349200..40a4770f9c 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -32,7 +32,8 @@ paddle.init(use_gpu=False, trainer_count=1) 设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`。 ```python -DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. +# Use 3 * 331 * 331 or 3 * 299 * 299 for DATA_DIM in Inception-ResNet-v2. +DATA_DIM = 3 * 224 * 224 CLASS_DIM = 102 BATCH_SIZE = 128 @@ -94,7 +95,8 @@ out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) 提供的Inception-ResNet-v2模型支持`3 * 331 * 331`和`3 * 299 * 299`两种大小的输入,同时可以自行设置dropout概率,可以通过如下的代码使用: ```python -out = inception_resnet_v2.inception_resnet_v2(image, class_dim=CLASS_DIM, dropout_rate=0.5, size=DATA_DIM) +out = inception_resnet_v2.inception_resnet_v2( + image, class_dim=CLASS_DIM, dropout_rate=0.5, size=DATA_DIM) ``` 注意,由于和其他几种模型输入大小不同,若配合提供的`reader.py`使用Inception-ResNet-v2时请先将`reader.py`中`paddle.image.simple_transform`中的参数为修改为相应大小。 diff --git a/image_classification/inception_resnet_v2.py b/image_classification/inception_resnet_v2.py index 6dd08c5e7a..cddd59ce4b 100644 --- a/image_classification/inception_resnet_v2.py +++ b/image_classification/inception_resnet_v2.py @@ -8,6 +8,7 @@ def conv_bn_layer(input, padding=0, active_type=paddle.activation.Relu(), ch_in=None): + """layer wrapper assembling convolution and batchnorm layer""" tmp = paddle.layer.img_conv( input=input, filter_size=filter_size, @@ -21,6 +22,7 @@ def conv_bn_layer(input, def sequential_block(input, *layers): + """helper function for sequential layers""" for layer in layers: layer_func, layer_conf = layer input = layer_func(input, **layer_conf)