Skip to content

Commit

Permalink
Update export.py
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Sep 18, 2021
1 parent 7bf68ad commit b191ace
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript.pt')

ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')
LOGGER.info(f'{prefix} export failure: {e}')


def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
Expand All @@ -68,7 +68,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
check_requirements(('onnx',))
import onnx

print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
Expand All @@ -83,27 +83,27 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print
# LOGGER.info(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
if simplify:
try:
check_requirements(('onnx-simplifier',))
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=dynamic,
input_shapes={'images': list(im.shape)} if dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
LOGGER.info(f'{prefix} simplifier failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
except Exception as e:
print(f'{prefix} export failure: {e}')
LOGGER.info(f'{prefix} export failure: {e}')


def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
Expand All @@ -113,17 +113,17 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
check_requirements(('coremltools',))
import coremltools as ct

print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = file.with_suffix('.mlmodel')

model.train() # CoreML exports should be placed in model.train() mode
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255.0, bias=[0, 0, 0])])
ct_model.save(f)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')

return ct_model

Expand All @@ -138,7 +138,7 @@ def export_saved_model(model, im, file, dynamic,
from tensorflow import keras
from models.tf import TFModel, TFDetect

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(file).replace('.pt', '_saved_model')
batch_size, ch, *imgsz = list(im.shape) # BCHW

Expand All @@ -152,9 +152,9 @@ def export_saved_model(model, im, file, dynamic,
keras_model.summary()
keras_model.save(f, save_format='tf')

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')

return keras_model

Expand All @@ -165,7 +165,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = file.with_suffix('.pb')

m = tf.function(lambda x: keras_model(x)) # full model
Expand All @@ -174,9 +174,9 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
Expand All @@ -185,7 +185,7 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
import tensorflow as tf
from models.tf import representative_dataset_gen

print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
batch_size, ch, *imgsz = list(im.shape) # BCHW
f = str(file).replace('.pt', '-fp16.tflite')

Expand All @@ -205,10 +205,10 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te

tflite_model = converter.convert()
open(f, "wb").write(tflite_model)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
Expand All @@ -217,17 +217,17 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
check_requirements(('tensorflowjs',))
import tensorflowjs as tfjs

print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path

cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
subprocess.run(cmd, shell=True)

print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
LOGGER.info(f'\n{prefix} export failure: {e}')


@torch.no_grad()
Expand Down Expand Up @@ -278,7 +278,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'

for _ in range(2):
y = model(im) # dry runs
print(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")

# Exports
if 'torchscript' in include:
Expand All @@ -301,9 +301,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
export_tfjs(model, im, file)

# Finish
print(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f'\nVisualize with https://netron.app')
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f'\nVisualize with https://netron.app')


def parse_opt():
Expand Down

0 comments on commit b191ace

Please sign in to comment.