Conv only for latency (PaddlePaddle#37)
ceci3 authored and wanghaoshuang committed Jan 16, 2020
1 parent 50d69ec commit 76f0d7e
Showing 2 changed files with 10 additions and 12 deletions.
1 change: 0 additions & 1 deletion paddleslim/analysis/flops.py
@@ -66,7 +66,6 @@ def _graph_flops(graph, only_conv=True, detail=False):
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
-            flops += x_shape[0] * x_shape[1] * y_shape[1]

             op_flops = x_shape[0] * x_shape[1] * y_shape[1]
             flops += op_flops
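A note on the flops.py change above: the deleted line had been adding the mul op's FLOPs a second time on top of the op_flops accumulation kept just below it, so fully connected layers were counted twice. Below is a minimal standalone sketch of the corrected accounting; the shapes are hypothetical stand-ins for what flops.py reads from op.inputs("X") and op.inputs("Y"), not values taken from this commit.

def mul_flops(x_shape, y_shape):
    # FLOPs of a mul (fully connected) op: batch * in_features * out_features.
    x_shape = list(x_shape)
    if x_shape[0] == -1:   # unknown batch dimension is normalized to 1
        x_shape[0] = 1
    return x_shape[0] * x_shape[1] * y_shape[1]

flops = 0
for x_shape, y_shape in [((-1, 1024), (1024, 10)), ((-1, 10), (10, 2))]:
    op_flops = mul_flops(x_shape, y_shape)
    flops += op_flops      # accumulated exactly once per op
print(flops)               # (1 * 1024 * 10) + (1 * 10 * 2) = 10260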
21 changes: 10 additions & 11 deletions paddleslim/analysis/latency.py
@@ -24,7 +24,7 @@ def __init__(self):
     def latency(self, graph):
         pass

-    def _get_ops_from_graph(self, graph):
+    def _get_ops_from_graph(self, graph, only_conv):
         assert isinstance(graph, GraphWrapper)
         ops = []
         i = 0
@@ -33,22 +33,20 @@ def _get_ops_from_graph(self, graph):
                 tmp = self._conv_op_args(op)
             elif op.type() in [
                     'elementwise_add', 'elementwise_mul', 'elementwise_max'
-            ]:
+            ] and only_conv == False:
                 tmp = self._eltwise_op_args(op)
             elif op.type() in [
                     'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
                     'leaky_relu'
-            ]:
+            ] and only_conv == False:
                 tmp = self._activation_op_args(op)
-            elif op.type() == 'batch_norm':
+            elif op.type() == 'batch_norm' and only_conv == False:
                 tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'pool2d':
+            elif op.type() == 'pool2d' and only_conv == False:
                 tmp = self._pooling_op_args(op)
-            elif op.type() == 'batch_norm':
-                tmp = self._batch_norm_op_args(op)
-            elif op.type() == 'softmax':
+            elif op.type() == 'softmax' and only_conv == False:
                 tmp = self._softmax_op_args(op)
-            elif op.type() == 'mul':
+            elif op.type() == 'mul' and only_conv == False:
                 tmp = self._fc_op_args(op)
             else:
                 tmp = None
@@ -268,19 +266,20 @@ def _op_latency(self, op_str):
         assert op_str in self._table
         return self._table[op_str]

-    def latency(self, graph):
+    def latency(self, graph, only_conv=True):
         """
         Get latency of the target graph.
         Args:
         - graph(GraphWrapper | Program): The graph to be evaluated.
+        - only_conv(bool): Only evaluate convolution layers if `only_conv` is True. Default: True.
         Returns:
         latency(float): The latency of the given graph on the current evaluator.
         """
         total_latency = 0
         if isinstance(graph, Program):
             graph = GraphWrapper(graph)
         assert isinstance(graph, GraphWrapper)
-        for op in self._get_ops_from_graph(graph):
+        for op in self._get_ops_from_graph(graph, only_conv):
             total_latency += self._op_latency(
                 self._delimiter.join(map(lambda x: str(x), op)))
         return total_latency
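For reference, a minimal usage sketch of the new only_conv parameter. It assumes the table-based evaluator TableLatencyEvaluator defined in this module is exported from paddleslim.analysis and that a hardware latency table file is available; the table file name, the constructor argument, and the example network are illustrative assumptions, not part of this commit.

import paddle.fluid as fluid
from paddleslim.analysis import TableLatencyEvaluator  # assumed export path

# Illustrative network: one conv layer followed by a pooling layer.
main_program = fluid.Program()
with fluid.program_guard(main_program, fluid.Program()):
    image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
    conv = fluid.layers.conv2d(image, num_filters=32, filter_size=3)
    pool = fluid.layers.pool2d(conv, pool_size=2)

# 'latency_table.txt' is a placeholder for a latency table in whatever
# format this evaluator expects; it is not part of this commit.
evaluator = TableLatencyEvaluator(table_file='latency_table.txt')

# New default after this change: only convolution ops contribute.
conv_only_latency = evaluator.latency(main_program)

# Previous behaviour: eltwise, activation, batch_norm, pool2d, softmax
# and mul ops are evaluated as well.
full_latency = evaluator.latency(main_program, only_conv=False)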
