diff --git a/python/paddle/fluid/contrib/slim/core/compress_pass.py b/python/paddle/fluid/contrib/slim/core/compress_pass.py index ed9fce12b6f7b..de14d7fcec430 100644 --- a/python/paddle/fluid/contrib/slim/core/compress_pass.py +++ b/python/paddle/fluid/contrib/slim/core/compress_pass.py @@ -13,6 +13,7 @@ # limitations under the License. from ....core import CPUPlace +from .... import io from ....data_feeder import DataFeeder from ..graph import get_executor, ImitationGraph from config import ConfigFactory @@ -41,6 +42,7 @@ def __init__(self, train_reader=None, eval_graph=None, eval_reader=None, + teacher_graphs=None, optimizer=None): # The total number of epoches to be trained. self.epoch = 0 @@ -57,6 +59,7 @@ def __init__(self, self.eval_graph = eval_graph self.eval_reader = eval_reader self.executor = None + self.teacher_graphs = teacher_graphs self.optimizer = optimizer def run_eval_graph(self): @@ -159,8 +162,9 @@ def config(self, config_file): def _load_checkpoint(self, context): if self.checkpoint: - exe = get_executor(context.train_graph, parallel=False) - fluid.io.load_persistables( + exe = get_executor( + context.train_graph, context.place, parallel=False) + io.load_persistables( exe.exe, self.checkpoint, main_program=context.train_graph.program) @@ -168,12 +172,13 @@ def _load_checkpoint(self, context): def _save_checkpoint(self, context): if context.epoch_id % 5 == 0 and self.model_save_dir: - model_path = os.path.join(self.model_save_dir, - str(context.epoch_id)) + model_path = os.path.join( + self.model_save_dir, + str(context.epoch_id) + "_" + str(context.batch_id)) if not os.path.isdir(model_path): os.makedirs(model_path) - exe = get_executor(context.train_graph, parallel=False) - fluid.io.save_persistables( + exe = get_executor(context.train_graph, context.place, False) + io.save_persistables( exe.exe, model_path, main_program=context.train_graph.program) print('Saved checkpoint to: {}'.format(model_path)) @@ -209,6 +214,7 @@ def run(self): train_reader=self.train_reader, eval_graph=self.eval_graph, eval_reader=self.eval_reader, + teacher_graphs=self.teacher_graphs, optimizer=self.optimizer) self._load_checkpoint(context) diff --git a/python/paddle/fluid/contrib/slim/demo/distillation/compress.py b/python/paddle/fluid/contrib/slim/demo/distillation/compress.py new file mode 100644 index 0000000000000..31c6ff98474c1 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/demo/distillation/compress.py @@ -0,0 +1,91 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle +import os +import sys +from resnet import * +from paddle.fluid.contrib.slim import CompressPass +from paddle.fluid.contrib.slim import build_compressor +from paddle.fluid.contrib.slim import ImitationGraph + + +class Model(object): + def __init__(slef): + pass + + def compress(self): + + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + resnet50 = ResNet50() + predict = resnet50.net(img, class_dim=10) + eval_program = fluid.default_main_program().clone(for_test=False) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + + with fluid.program_guard(main_program=eval_program): + acc = fluid.layers.accuracy(input=predict, label=label) + + optimizer = fluid.optimizer.SGD(0.001) + optimizer.minimize(avg_cost) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=32) + eval_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=1) + + train_feed_list = {'img': img.name, 'label': label.name} + train_fetch_list = {'cost': avg_cost.name} + eval_feed_list = {'img': img.name, 'label': label.name} + eval_fetch_list = {'acc': acc.name} + + # define teacher program + teacher_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(teacher_program, startup_program): + img = fluid.layers.data( + name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + resnet101 = ResNet101() + predict = resnet101.net(img, class_dim=10) + exe.run(startup_program) + + com_pass = CompressPass( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=eval_program, + eval_reader=eval_reader, + eval_feed_list=eval_feed_list, + eval_fetch_list=eval_fetch_list, + teacher_programs=[teacher_program], + optimizer=optimizer) + com_pass.model_save_dir = './checkpoints' + com_pass.config('./config.yaml') + com_pass.run() + + +if __name__ == "__main__": + model = Model() + model.compress() diff --git a/python/paddle/fluid/contrib/slim/demo/distillation/config.yaml b/python/paddle/fluid/contrib/slim/demo/distillation/config.yaml new file mode 100644 index 0000000000000..78b0bc6cb0eb0 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/demo/distillation/config.yaml @@ -0,0 +1,14 @@ +version: 1.0 +distillers: + fsp_distiller: + class: 'FSPDistiller' +strategies: + fsp_distillation_strategy: + class: 'FSPDistillationStrategy' + distiller: 'fsp_distiller' + start_epoch: 0 + end_epoch: 10 +compress_pass: + epoch: 10 + strategies: + - fsp_distillation_strategy diff --git a/python/paddle/fluid/contrib/slim/demo/distillation/resnet.py b/python/paddle/fluid/contrib/slim/demo/distillation/resnet.py new file mode 100644 index 0000000000000..26c6ca143cdd6 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/demo/distillation/resnet.py @@ -0,0 +1,137 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import math + +__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNet(): + def __init__(self, layers=50): + self.params = train_parameters + self.layers = layers + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, + stdv))) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + return fluid.layers.batch_norm(input=conv, act=act) + + def shortcut(self, input, ch_out, stride): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride) + else: + return input + + def bottleneck_block(self, input, num_filters, stride): + conv0 = self.conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu') + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters * 4, filter_size=1, act=None) + + short = self.shortcut(input, num_filters * 4, stride) + + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') + + +def ResNet50(): + model = ResNet(layers=50) + return model + + +def ResNet101(): + model = ResNet(layers=101) + return model + + +def ResNet152(): + model = ResNet(layers=152) + return model diff --git a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py new file mode 100644 index 0000000000000..3359a15f00361 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py @@ -0,0 +1,41 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..core.strategy import Strategy +from ....framework import Program, program_guard, Parameter +from .... import layers +import numpy as np +import copy +import re + +__all__ = ['FSPDistillationStrategy'] + + +class FSPDistillationStrategy(Strategy): + def __init__(self, distiller=None, start_epoch=0, end_epoch=10): + super(FSPDistillationStrategy, self).__init__(start_epoch, end_epoch) + self.distiller = distiller + self.train_graph_backup = None + + def on_epoch_begin(self, context): + if self.start_epoch == context.epoch_id: + self.train_graph_backup = context.train_graph + graph = self.distiller.distiller_graph( + context.eval_graph, context.teacher_graphs, context.optimizer, + context.place) + context.train_graph = graph + + def on_epoch_end(self, context): + if context.epoch_id == (self.end_epoch - 1): + context.train_graph = self.train_graph_backup diff --git a/python/paddle/fluid/contrib/slim/distillation/distiller.py b/python/paddle/fluid/contrib/slim/distillation/distiller.py new file mode 100644 index 0000000000000..e9a394de571d6 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/distillation/distiller.py @@ -0,0 +1,121 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .... import layers +from .... import Executor +from .... import Program +from .... import program_guard + +__all__ = ['FSPDistiller'] + + +class FSPDistiller(object): + def __init__(self, student_pairs=None, teacher_pairs=None): + self.student_pairs = student_pairs + self.teacher_pairs = teacher_pairs + + def _feature_map_pairs(self, graph): + pairs = [] + sizes = [] + pair = [] + pre_size = None + for op in graph.all_ops(): + if op.type == 'conv2d': + out_var_name = op.output('Output')[0] + feature_map_size = graph.get_var(out_var_name).shape[2:] + if feature_map_size != pre_size: + if len(pair) == 2 and pair[1] != None: + pairs.append(pair) + sizes.append(pre_size) + pair = [out_var_name, None] + else: + pair[1] = out_var_name + pre_size = feature_map_size + if len(pair) == 2 and pair[1] != None: + pairs.append(pair) + sizes.append(pre_size) + return pairs, sizes + + def distiller_graph(self, student, teachers, optimizer, place): + """ + Generate distillation training graph. + """ + teacher = teachers[0] + for var in teacher.program.list_vars(): + var.stop_gradient = True + # step 1: merge student and teacher into graph + graph = student.clone() + graph.merge(teacher) + if not self.student_pairs: + self.student_pairs = self._feature_map_pairs(student) + if not self.teacher_pairs: + self.teacher_pairs = self._feature_map_pairs(teacher) + # step 2: add fsp loss and backward ops + distiller_pass = FSPDistillerPass(self.student_pairs, + self.teacher_pairs, optimizer, place) + dis_graph = distiller_pass.apply(graph) + return dis_graph + + +class FSPDistillerPass(object): + ''' + Convert graph to fsp distillation training graph + by adding fsp loss and backward operators. + ''' + + def __init__(self, s_pairs, t_pairs, optimizer, place): + self.s_pairs = s_pairs + self.t_pairs = t_pairs + self.optimizer = optimizer + self.place = place + + def apply(self, graph): + ret_graph = graph.clone() + startup_program = Program() + with program_guard(ret_graph.program, startup_program): + losses = [] + for s_pair, t_pair in zip(self.s_pairs, self.t_pairs): + s_pair_start = ret_graph.get_var(s_pair[0]) + s_pair_end = ret_graph.get_var(s_pair[1]) + s_fsp_matrix = self._fsp_matrix(s_pair_start, s_pair_end) + t_pair_start = ret_graph.get_var(t_pair[0]) + t_pair_end = ret_graph.get_var(t_pair[1]) + t_fsp_matrix = self._fsp_matrix(t_pair_start, t_pair_end) + l2_loss = layers.mean( + layers.square(s_fsp_matrix - t_fsp_matrix)) + losses.append(l2_loss) + loss = layers.sum(losses) + self.optimizer.minimize(loss) + + exe = Executor(self.place) + # init variable created when append backward ops. Such as leaning rate + # and accumulators in some optimizer. + exe.run(startup_program, scope=ret_graph.scope) + ret_graph.out_nodes['loss'] = loss.name + return ret_graph + + def fsp_matrix(self, fea_map_0, fea_map_1): + a_channel = fea_map_0.shape[1] + b_channel = fea_map_1.shape[1] + h = fea_map_0.shape[2] + w = fea_map_0.shape[3] + tmp_0 = layers.transpose(fea_map_0, perm=[0, 2, 3, 1]) + tmp_0 = layers.reshape(tmp_0, [-1, h * w, 1, a_channel]) + tmp_0 = layers.expand(tmp_0, expand_times=[1, 1, b_channel, 1]) + tmp_0 = layers.transpose(tmp_0, perm=[0, 1, 3, 2]) + + tmp_1 = layers.transpose(fea_map_1, perm=[0, 2, 3, 1]) + tmp_1 = layers.reshape(tmp_1, [-1, h * w, 1, b_channel]) + tmp_1 = layers.expand(tmp_1, expand_times=[1, 1, a_channel, 1]) + return layers.reduce_mean(tmp_0 * tmp_1, dim=1) diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py index c726a7a590cc9..63fc5ac192410 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph.py +++ b/python/paddle/fluid/contrib/slim/graph/graph.py @@ -251,6 +251,28 @@ def clone(self, for_test=False): self.program.clone(for_test), self.scope, copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes)) + def merge(self, graph): + for var in graph.program.list_vars(): + self.program.global_block()._clone_variable(var) + for op in graph.all_ops(): + inputs = {} + outputs = {} + attrs = {} + for input_name in op.input_names: + inputs[input_name] = [ + self.get_var(in_var_name) + for in_var_name in op.input(input_name) + ] + for output_name in op.output_names: + outputs[output_name] = [ + self.get_var(out_var_name) + for out_var_name in op.output(output_name) + ] + for attr_name in op.attr_names: + attrs[attr_name] = op.attr(attr_name) + self.program.global_block().append_op( + type=op.type, inputs=inputs, outputs=outputs, attrs=attrs) + def ops(self): return self.program.global_block().ops diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8d061f41f09a8..66281f7007a1a 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1487,7 +1487,7 @@ def _clone_variable(self, var): shape=var.shape, dtype=var.dtype, type=var.type, - persistable=True, + persistable=var.persistable, is_data=var.is_data) else: ret_var = self.create_var( @@ -1496,7 +1496,8 @@ def _clone_variable(self, var): dtype=var.dtype, type=var.type, lod_level=var.lod_level, - persistable=True, + persistable=var.persistable, + stop_gradient=var.stop_gradient, is_data=var.is_data) return ret_var