Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update summarize.py to select different netstate options #4174

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 89 additions & 7 deletions tools/extra/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,89 @@
'444', '103;30', '107;30']
DISCONNECTED_COLOR = '41'

def read_net(filename):
def read_net(filename, phase=None, stages=None, level=None):
net = caffe_pb2.NetParameter()
with open(filename) as f:
protobuf.text_format.Parse(f.read(), net)
return net

if phase is not None or stages is not None or level is not None:
# Create the rule
rule = caffe_pb2.NetStateRule()
if phase.lower() == 'train':
protobuf.text_format.Merge('phase: TRAIN', rule)
elif phase.lower() == 'test':
protobuf.text_format.Merge('phase: TEST', rule)
if stages:
for stage in stages:
protobuf.text_format.Merge('stage: "%s"' % stage, rule)
if level is not None:
protobuf.text_format.Merge('level: %d' % level, rule)

print '>>> NetStateRule'
print protobuf.text_format.MessageToString(rule)

# Filter by the rule
layers = []
for layer in net.layer:
if layer_meets_rule(layer, rule):
layers.append(layer)
return layers

else:
return net.layer


def layer_meets_rule(layer, state):
"""
Returns True if this layer will be included in the given NetStateRule
Logic copied from Net::FilterNet()
"""
# If no include rules are specified, the layer is included by default and
# only excluded if it meets one of the exclude rules.
layer_included = len(layer.include) == 0

for exclude_rule in layer.exclude:
if state_meets_rule(state, exclude_rule):
layer_included = False
break

for include_rule in layer.include:
if state_meets_rule(state, include_rule):
layer_included = True
break

return layer_included


def state_meets_rule(state, rule):
"""
Returns True if the given state meets the given NetStateRule
Logic copied from Net::StateMeetsRule()
"""
if rule.HasField('phase'):
if rule.phase != state.phase:
return False

if rule.HasField('min_level'):
if state.level < rule.min_level:
return False

if rule.HasField('max_level'):
if state.level > rule.max_level:
return False

# The state must contain ALL of the rule's stages
for stage in rule.stage:
if stage not in state.stage:
return False

# The state must contain NONE of the rule's not_stages
for stage in rule.not_stage:
if stage in state.stage:
return False

return True


def format_param(param):
out = []
Expand Down Expand Up @@ -60,15 +138,15 @@ def print_table(table, max_width):
row_str += ' ' * max(right_col - printed_len(row_str), 0)
print row_str

def summarize_net(net):
def summarize_net(layers):
disconnected_tops = set()
for lr in net.layer:
for lr in layers:
disconnected_tops |= set(lr.top)
disconnected_tops -= set(lr.bottom)

table = []
colors = {}
for lr in net.layer:
for lr in layers:
tops = []
for ind, top in enumerate(lr.top):
color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
Expand Down Expand Up @@ -130,10 +208,14 @@ def main():
parser.add_argument('filename', help='net prototxt file to summarize')
parser.add_argument('-w', '--max-width', help='maximum field width',
type=int, default=30)
parser.add_argument('--phase', help='NetState.phase')
parser.add_argument('--stage', help='NetState.stage',
nargs='*', dest='stages')
parser.add_argument('--level', type=int, help='NetState.level')
args = parser.parse_args()

net = read_net(args.filename)
table = summarize_net(net)
layers = read_net(args.filename, phase=args.phase, stages=args.stages, level=args.level)
table = summarize_net(layers)
print_table(table, max_width=args.max_width)

if __name__ == '__main__':
Expand Down