Skip to content

Commit

Permalink
Add NHWC TODO (PaddlePaddle#96)
Browse files Browse the repository at this point in the history
* Add TODO LIST
  • Loading branch information
GuoxiaWang authored Nov 19, 2021
1 parent 40f4612 commit 3a09f50
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
4 changes: 3 additions & 1 deletion scripts/perf_runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ fi

if [ $dtype = "fp16" ]; then
fp16=True
data_format=NHWC
data_format=NCHW
# TODO(GuoxiaWang): remove NCHW when PRelu support NHWC
# data_format=NHWC
else
fp16=False
data_format=NCHW
Expand Down
23 changes: 13 additions & 10 deletions static/backbones/iresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self,
units = [3, 8, 36, 3]
filter_list = [64, 64, 128, 256, 512]
num_stages = 4

if data_format == 'NHWC':
image = paddle.tensor.transpose(image, [0, 2, 3, 1])

Expand All @@ -86,6 +86,7 @@ def __init__(self,
momentum=0.9,
data_layout=data_format,
is_test=False if is_train else True)
# TODO(GuoxiaWang): add data_format attr
input_blob = paddle.static.nn.prelu(
input_blob,
mode="channel",
Expand All @@ -95,18 +96,20 @@ def __init__(self,
for i in range(num_stages):
for j in range(units[i]):
input_blob = self.residual_unit_v3(
input_blob,
filter_list[i + 1],
3,
2 if j == 0 else 1,
1,
is_train, data_format)
input_blob, filter_list[i + 1], 3, 2
if j == 0 else 1, 1, is_train, data_format)
fc1 = self.get_fc1(input_blob, is_train, dropout, data_format)

self.output_dict['feature'] = fc1

def residual_unit_v3(self, in_data, num_filter, filter_size, stride, pad,
is_train, data_format="NCHW"):
def residual_unit_v3(self,
in_data,
num_filter,
filter_size,
stride,
pad,
is_train,
data_format="NCHW"):

bn1 = paddle.static.nn.batch_norm(
input=in_data,
Expand All @@ -132,7 +135,7 @@ def residual_unit_v3(self, in_data, num_filter, filter_size, stride, pad,
momentum=0.9,
data_layout=data_format,
is_test=False if is_train else True)
# prelu = paddle.nn.functional.relu6(bn2)
# TODO(GuoxiaWang): add data_format attr
prelu = paddle.static.nn.prelu(
bn2,
mode="channel",
Expand Down

0 comments on commit 3a09f50

Please sign in to comment.