Skip to content

Commit

Permalink
allow security segmentation models to be used in auto annotation (cva…
Browse files Browse the repository at this point in the history
  • Loading branch information
benhoff authored and Chris Lee-Messer committed Mar 5, 2020
1 parent c043228 commit 3826bfd
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions cvat/apps/auto_annotation/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ def __init__(self, model, weights):

iter_inputs = iter(network.inputs)
self._input_blob_name = next(iter_inputs)
self._input_info_name = ''
self._output_blob_name = next(iter(network.outputs))

self._require_image_info = False

info_names = ('image_info', 'im_info')

# NOTE: handeling for the inclusion of `image_info` in OpenVino2019
if 'image_info' in network.inputs:
if any(s in network.inputs for s in info_names):
self._require_image_info = True
if self._input_blob_name == 'image_info':
self._input_info_name = set(network.inputs).intersection(info_names)
self._input_info_name = self._input_info_name.pop()
if self._input_blob_name in info_names:
self._input_blob_name = next(iter_inputs)

self._net = plugin.load(network=network, num_requests=2)
Expand All @@ -56,7 +61,7 @@ def infer(self, image):
info[0, 1] = w
# frame number
info[0, 2] = 1
inputs['image_info'] = info
inputs[self._input_info_name] = info

results = self._net.infer(inputs)
if len(results) == 1:
Expand Down

0 comments on commit 3826bfd

Please sign in to comment.