Skip to content

Commit

Permalink
Fix MaskRCNN conversion bug (#627)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Zhu <tylerz@nvidia.com>

Co-authored-by: Tyler Zhu <tylerz@nvidia.com>
  • Loading branch information
Tyler-D and Tyler-D authored Jun 18, 2020
1 parent 81448ca commit 2b8863d
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
From 00cfd7d7ce323df1f71048d959681349ca4fed78 Mon Sep 17 00:00:00 2001
From 9d7ef2f151d7488ddd06e06924dd04e27c9f92b0 Mon Sep 17 00:00:00 2001
From: Nine Feng <nfeng@nvidia.com>
Date: Wed, 24 Jul 2019 09:22:44 +0800
Subject: [PATCH] Update the Mask_RCNN model from NHWC to NCHW
Expand All @@ -17,7 +17,7 @@ Subject: [PATCH] Update the Mask_RCNN model from NHWC to NCHW
2 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/mrcnn/model.py b/mrcnn/model.py
index 62cb2b0..1508f2c 100644
index 67c2e4f..d297ae8 100644
--- a/mrcnn/model.py
+++ b/mrcnn/model.py
@@ -110,17 +110,17 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
Expand Down Expand Up @@ -116,23 +116,23 @@ index 62cb2b0..1508f2c 100644


############################################################
@@ -853,8 +864,9 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
@@ -853,8 +864,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
activation='linear', name='rpn_class_raw')(shared)

# Reshape to [batch, anchors, 2]
- rpn_class_logits = KL.Lambda(
- lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 2]))(x)
+ x = KL.Permute((2,3,1))(x)
+ rpn_class_logits = KL.Reshape((-1, 2))(x)
+ x = KL.Permute((2,3,1))(x)

# Softmax on last dimension of BG/FG.
rpn_probs = KL.Activation(
@@ -866,7 +878,7 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
@@ -866,7 +877,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
activation='linear', name='rpn_bbox_pred')(shared)

# Reshape to [batch, anchors, 4]
- rpn_bbox = KL.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 4]))(x)
+ x = KL.Permute((2,3,1))(x)
+ rpn_bbox = KL.Reshape((-1, 4))(x)

return [rpn_class_logits, rpn_probs, rpn_bbox]
Expand Down

0 comments on commit 2b8863d

Please sign in to comment.