Skip to content

Commit

Permalink
Add the PyTorch passes to the tflite converter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662215479
  • Loading branch information
majiddadashi authored and copybara-github committed Aug 13, 2024
1 parent f084804 commit f062507
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ai_edge_torch/lowertools/odml_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch

from tensorflow.compiler.tf2xla.python import xla as tfxla
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb

MlirBundle = odml_torch.export.MlirLowered

Expand Down Expand Up @@ -162,7 +163,9 @@ def merged_bundle_to_tfl_model(
)

converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
converter._experimental_enable_composite_direct_lowering = True
converter.model_origin_framework = "PYTORCH"

conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

Expand Down

0 comments on commit f062507

Please sign in to comment.