Skip to content

Commit

Permalink
add export script
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Nov 24, 2024
1 parent a1bb322 commit 9a14be4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
18 changes: 15 additions & 3 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402
from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402
from sim.model_export import ActorCfg, convert_model_to_onnx, get_actor_policy # noqa: E402
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402
from kinfer.export.pytorch import export_to_onnx

import torch # isort: skip

Expand Down Expand Up @@ -81,8 +82,19 @@ def play(args: argparse.Namespace) -> None:
# export policy as a onnx module (used to run it on web)
if args.export_onnx:
path = ppo_runner.alg.actor_critic
convert_model_to_onnx(path, ActorCfg(), save_path="policy.onnx")
print("Exported policy as onnx to: ", path)
policy_cfg = ActorCfg()
actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}

policy = export_to_onnx(
actor_model,
input_tensors=input_tensors,
config=export_config,
save_path="kinfer_policy.onnx"
)
print("Exported policy as kinfer-compatible onnx to: ", path)

# Prepare for logging
env_logger = Logger(env.dt)
Expand Down
2 changes: 2 additions & 0 deletions sim/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ wandb
tensorboard==2.14.0
onnxscript
# onnxruntime

kinfer
8 changes: 3 additions & 5 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tqdm import tqdm

from sim.h5_logger import HDF5Logger
from model_export import ActorCfg, get_actor_policy, convert_model_to_onnx
from sim.model_export import ActorCfg, get_actor_policy
from kinfer.export.pytorch import export_to_onnx


Expand Down Expand Up @@ -363,10 +363,8 @@ def new_func(args, policy_cfg):
if args.load_model.endswith(".onnx"):
policy = ort.InferenceSession(args.load_model)
else:
# Export function is able to infer input shapes
# actor_model = new_func(args, policy_cfg)
# actor_model = torch.jit.load(args.load_model)
actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}
print(export_config)
Expand All @@ -377,7 +375,7 @@ def new_func(args, policy_cfg):
save_path="kinfer_test.onnx"
)
# policy = convert_model_to_onnx(args.load_model, policy_cfg, save_path="policy.onnx")

model_info = parse_modelmeta(
policy.get_modelmeta().custom_metadata_map.items(),
verbose=True,
Expand Down

0 comments on commit 9a14be4

Please sign in to comment.