
Commit

Add default walker to handle unregistered operators in pruning. (PaddleP…
wanghaoshuang authored Apr 9, 2020
1 parent 84d7653 commit 502f761
Showing 4 changed files with 31 additions and 13 deletions.
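In short: before this commit, the pruner asserted as soon as it met an operator whose pruning walker was not registered in `PRUNE_WORKER`; now it logs a warning and falls back to a new generic `default_walker` that propagates the pruned indices through any input or output whose rank matches the pruned variable. The snippet below is a minimal sketch of that lookup-with-fallback pattern, written against a stand-in registry rather than the real PaddleSlim classes:

```python
# Minimal sketch of the fallback pattern this commit introduces.
# PRUNE_WORKER here is a plain dict standing in for PaddleSlim's Registry.
import logging

_logger = logging.getLogger(__name__)

PRUNE_WORKER = {}  # maps op type -> walker class


def register(cls):
    """Register a walker class under its own name."""
    PRUNE_WORKER[cls.__name__] = cls
    return cls


@register
class default_walker(object):
    """Generic walker: keeps input/output shapes consistent (see the diff below)."""


def get_walker(op_type):
    cls = PRUNE_WORKER.get(op_type)
    if cls is None:
        # Instead of failing an assertion, warn and fall back.
        _logger.warning("%s has no registered walker; using default_walker.",
                        op_type)
        cls = PRUNE_WORKER["default_walker"]
    return cls


print(get_walker("some_unregistered_op"))  # -> <class '...default_walker'>
```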
3 changes: 0 additions & 3 deletions paddleslim/core/registry.py
@@ -25,9 +25,6 @@ def get(self, key):
         return self._module_dict.get(key, None)
 
     def _register_module(self, module_class):
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, but receive {}.'.format(
-                type(module_class)))
         module_name = module_class.__name__
         if module_name in self._module_dict:
             raise KeyError('{} is already registered in {}.'.format(
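A plausible reading of this change (not stated in the commit itself): dropping the `inspect.isclass` check lets plain functions, such as pruning criteria and index selectors, be registered alongside walker classes. A stand-in sketch of a registry that accepts both, not the PaddleSlim implementation:

```python
# Stand-in Registry sketch: accepts anything with a __name__, so both
# classes and functions can be registered. Names below are illustrative.
class Registry(object):
    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def get(self, key):
        return self._module_dict.get(key, None)

    def register(self, module):
        module_name = module.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}.'.format(
                module_name, self._name))
        self._module_dict[module_name] = module
        return module


CRITERION = Registry('criterion')


@CRITERION.register
def l1_norm(group, graph):  # a function, not a class, can now be registered
    pass
```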
4 changes: 1 addition & 3 deletions paddleslim/prune/idx_selector.py
@@ -52,12 +52,11 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
 
-    pruned_num = len(sorted_idx) * ratio
+    pruned_num = int(round(len(sorted_idx) * ratio))
     pruned_idx = sorted_idx[:pruned_num]
 
     idxs = []
@@ -94,7 +93,6 @@ def optimal_threshold(group, ratio):
         list: pruned indexes
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
 
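The functional fix in this file is the `int(round(...))` around the pruned-channel count: `len(sorted_idx) * ratio` is a float, a float cannot be used as a slice bound (recent NumPy and Python 3 reject it), and plain truncation would systematically under-prune. A small worked example, with made-up numbers:

```python
# Worked example of the rounding fix (toy data, not PaddleSlim code).
import numpy as np

score = np.random.rand(17)  # pretend: per-channel importance scores
sorted_idx = score.argsort()
ratio = 0.3

# pruned_num = len(sorted_idx) * ratio            # 5.1 -> error when slicing
pruned_num = int(round(len(sorted_idx) * ratio))  # 5
pruned_idx = sorted_idx[:pruned_num]
print(pruned_num, pruned_idx.shape)  # 5 (5,)
```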
29 changes: 27 additions & 2 deletions paddleslim/prune/prune_walker.py
@@ -71,8 +71,11 @@ def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
         if visited is not None:
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
-        assert cls is not None, "The walker of {} is not registered.".format(
-            op.type())
+        if cls is None:
+            _logger.warn(
+                "{} op will be pruned by the default walker to keep the shapes of its input and output the same, because its walker is not registered.".
+                format(op.type()))
+            cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
         walker = cls(op,
@@ -236,6 +239,7 @@ def __init__(self, op, pruned_params, visited):
         super(elementwise_mul, self).__init__(op, pruned_params, visited)
 
 
+@PRUNE_WORKER.register
 class activation(PruneWorker):
     def __init__(self, op, pruned_params, visited):
         super(activation, self).__init__(op, pruned_params, visited)
@@ -256,6 +260,27 @@ def _prune(self, var, pruned_axis, pruned_idx):
                 self._prune_op(op, out_var, pruned_axis, pruned_idx)
 
 
+@PRUNE_WORKER.register
+class default_walker(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(default_walker, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if var in self.op.all_outputs():
+            for in_var in self.op.inputs():
+                if len(in_var.shape()) == len(var.shape()):
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for out_var in self.op.all_outputs():
+            if len(out_var.shape()) == len(var.shape()):
+                self._visit(out_var, pruned_axis)
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class uniform_random_batch_size_like(activation):
     def __init__(self, op, pruned_params, visited):
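`default_walker` only covers operators whose inputs and outputs share the pruned variable's rank; ops with other semantics should still get a dedicated walker. Below is a sketch of what such a registration could look like, assuming the names already defined in `paddleslim/prune/prune_walker.py` (`PRUNE_WORKER`, `PruneWorker`); the op name and prune rule are made up for illustration:

```python
# Illustrative only: a dedicated walker for an op where "same rank in,
# same rank out" is the wrong rule. `my_custom_op` is a hypothetical op type.
@PRUNE_WORKER.register
class my_custom_op(PruneWorker):
    def __init__(self, op, pruned_params, visited):
        super(my_custom_op, self).__init__(op, pruned_params, visited)

    def _prune(self, var, pruned_axis, pruned_idx):
        # Only propagate pruning to the op's outputs; leave its inputs alone.
        for out_var in self.op.all_outputs():
            self._visit(out_var, pruned_axis)
            for next_op in out_var.outputs():
                self._prune_op(next_op, out_var, pruned_axis, pruned_idx)
```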
8 changes: 3 additions & 5 deletions paddleslim/prune/pruner.py
@@ -41,8 +41,6 @@ class Pruner():
     def __init__(self,
                  criterion="l1_norm",
                  idx_selector="default_idx_selector"):
-        self.criterion = criterion
-        self.channel_sortor = channel_sortor
         if isinstance(criterion, str):
             self.criterion = CRITERION.get(criterion)
         else:
@@ -98,7 +96,7 @@ def prune(self,
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, aixs in group:
+                for name, axis in group:
                     pruned_params.append((name, axis, pruned_idx))
 
             else:
@@ -109,10 +107,10 @@ def prune(self,
                     values = np.array(scope.find_var(name).get_tensor())
                     group_values.append((name, values, axis))
 
-                scores = self.criterion(group_with_values,
+                scores = self.criterion(group_values,
                                         graph)  # [(name, axis, score)]
 
-                pruned_params = self.idx_selector(scores)
+                pruned_params.extend(self.idx_selector(scores, ratio))
 
         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
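Data-flow note on the last two hunks: the criterion now receives the `(name, values, axis)` tuples collected into `group_values` and returns `(name, axis, score)` tuples, the index selector turns those into `(name, axis, pruned_idx)` tuples for the requested `ratio`, and the results are accumulated with `extend` so several groups can contribute before being merged per parameter and axis. A simplified, self-contained sketch of that flow with toy data (not the PaddleSlim API):

```python
# Toy walk-through of criterion -> idx_selector -> merge (illustrative only).
import numpy as np

scores = [("conv1_weights", 0, np.array([0.9, 0.1, 0.5, 0.3]))]  # from criterion
ratio = 0.5


def toy_idx_selector(group, ratio):
    # Simplified stand-in for default_idx_selector: prune the lowest-scoring
    # channels of the first entry only.
    name, axis, score = group[0]
    sorted_idx = score.argsort()
    pruned_num = int(round(len(sorted_idx) * ratio))
    return [(name, axis, sorted_idx[:pruned_num].tolist())]


pruned_params = []
pruned_params.extend(toy_idx_selector(scores, ratio))

merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
    merge_pruned_params.setdefault(param, {}).setdefault(
        pruned_axis, []).extend(pruned_idx)

print(merge_pruned_params)  # {'conv1_weights': {0: [1, 3]}}
```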
