diff --git a/paddleslim/core/registry.py b/paddleslim/core/registry.py
index b746e5089f4c3..8d222cf0199b9 100644
--- a/paddleslim/core/registry.py
+++ b/paddleslim/core/registry.py
@@ -25,9 +25,6 @@ def get(self, key):
         return self._module_dict.get(key, None)
 
     def _register_module(self, module_class):
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, but receive {}.'.format(
-                type(module_class)))
         module_name = module_class.__name__
         if module_name in self._module_dict:
             raise KeyError('{} is already registered in {}.'.format(
diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py
index b17348ea6f3e8..58cf1111aad35 100644
--- a/paddleslim/prune/idx_selector.py
+++ b/paddleslim/prune/idx_selector.py
@@ -52,12 +52,11 @@ def default_idx_selector(group, ratio):
       list: pruned indexes
 
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
 
-    pruned_num = len(sorted_idx) * ratio
+    pruned_num = int(round(len(sorted_idx) * ratio))
     pruned_idx = sorted_idx[:pruned_num]
 
     idxs = []
@@ -94,7 +93,6 @@ def optimal_threshold(group, ratio):
      list: pruned indexes
 
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
 
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index 09bb547e6bf81..6e85be11a3f53 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -71,8 +71,11 @@ def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
         if visited is not None:
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
-        assert cls is not None, "The walker of {} is not registered.".format(
-            op.type())
+        if cls is None:
+            _logger.warn(
+                "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+                format(op.type()))
+            cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
         walker = cls(op,
@@ -236,6 +239,7 @@ def __init__(self, op, pruned_params, visited):
         super(elementwise_mul, self).__init__(op, pruned_params, visited)
 
 
+@PRUNE_WORKER.register
 class activation(PruneWorker):
     def __init__(self, op, pruned_params, visited):
         super(activation, self).__init__(op, pruned_params, visited)
@@ -256,6 +260,27 @@ def _prune(self, var, pruned_axis, pruned_idx):
             self._prune_op(op, out_var, pruned_axis, pruned_idx)
 
 
+@PRUNE_WORKER.register
+class default_walker(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(default_walker, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if var in self.op.all_outputs():
+            for in_var in self.op.inputs():
+                if len(in_var.shape()) == len(var.shape()):
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for out_var in self.op.all_outputs():
+            if len(out_var.shape()) == len(var.shape()):
+                self._visit(out_var, pruned_axis)
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class uniform_random_batch_size_like(activation):
     def __init__(self, op, pruned_params, visited):
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 317c5a9c91407..8169c56ba30cb 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -41,8 +41,6 @@ class Pruner():
     def __init__(self,
                  criterion="l1_norm",
                  idx_selector="default_idx_selector"):
-        self.criterion = criterion
-        self.channel_sortor = channel_sortor
         if isinstance(criterion, str):
             self.criterion = CRITERION.get(criterion)
         else:
@@ -98,7 +96,7 @@ def prune(self,
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, aixs in group:
+                for name, axis in group:
                     pruned_params.append((name, axis, pruned_idx))
 
             else:
@@ -109,10 +107,10 @@ def prune(self,
                     values = np.array(scope.find_var(name).get_tensor())
                     group_values.append((name, values, axis))
 
-                scores = self.criterion(group_with_values,
+                scores = self.criterion(group_values,
                                         graph)  # [(name, axis, score)]
 
-                pruned_params = self.idx_selector(scores)
+                pruned_params.extend(self.idx_selector(scores, ratio))
 
         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
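Side note on the `default_idx_selector` change: with `ratio` as a float, `len(sorted_idx) * ratio` is itself a float, and a non-integer value is rejected as a slice bound in `sorted_idx[:pruned_num]`; wrapping it in `int(round(...))` turns it into a valid channel count. A minimal standalone sketch of that behavior (NumPy only; the score values are made up for illustration and are not from the PR):

```python
import numpy as np

# Hypothetical per-channel scores; in PaddleSlim these come from the
# registered criterion (e.g. l1_norm), not from hard-coded values.
score = np.array([0.3, 0.1, 0.7, 0.5])
ratio = 0.5

sorted_idx = score.argsort()  # channel indexes, lowest score first -> [1 0 3 2]

# Old code: pruned_num = len(sorted_idx) * ratio == 2.0 (a float), and a
# non-integer slice bound is rejected when slicing sorted_idx.
pruned_num = int(round(len(sorted_idx) * ratio))  # fixed: 2
pruned_idx = sorted_idx[:pruned_num]

print(pruned_idx)  # [1 0] -> the two lowest-scoring channels are selected for pruning
```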