forked from ARISE-Initiative/robosuite-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rollout.py
114 lines (96 loc) · 3.79 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from util.rlkit_utils import simulate_policy
from util.arguments import add_rollout_args, parser
import robosuite as suite
from robosuite.wrappers import GymWrapper
from robosuite.controllers import ALL_CONTROLLERS, load_controller_config
import numpy as np
import torch
import imageio
import os
import json
from signal import signal, SIGINT
from sys import exit
# Set KMP_DUPLICATE_LIB_OK in the process environment before any rollout work.
# NOTE(review): presumably a workaround for the duplicate OpenMP runtime
# abort seen with some numpy / torch installs -- confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Add the rollout arguments to the parser, then parse the command line
add_rollout_args()
args = parser.parse_args()

# Module-level video writer handles read by the SIGINT handler; both stay
# None unless the main script sets up video recording.
video_writer = video_writer_obs = None
def handler(signal_received, frame):
    """Close any open video writers and exit when SIGINT is received.

    Args:
        signal_received: Signal number passed in by the signal module (unused).
        frame: Stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: Always, with exit code 0, once cleanup has finished.
    """
    # Handle any cleanup here
    print('SIGINT or CTRL-C detected. Closing video writer and exiting gracefully')
    # Either writer is still None when no video is being recorded (or when the
    # interrupt arrives before the writers are created), so guard each close.
    for writer in (video_writer, video_writer_obs):
        if writer is not None:
            writer.close()
    exit(0)
# Tell Python to run the handler() function when SIGINT is received
signal(SIGINT, handler)
if __name__ == "__main__":
    # Script entry point: load the saved environment configuration from
    # args.load_dir, rebuild the robosuite environment, and roll out the saved
    # policy, either rendering on screen or recording to mp4 files.

    # Set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Get path to the saved variant, which holds the environment kwargs
    kwargs_fpath = os.path.join(args.load_dir, "variant.json")
    try:
        with open(kwargs_fpath) as f:
            kwargs = json.load(f)
    except FileNotFoundError:
        print("Error opening variant file at: {}. "
              "Please check filepath and try again.".format(kwargs_fpath))
        # Nothing below can run without the variant; stop here instead of
        # failing later with a NameError on `kwargs`.
        exit(1)

    # Grab / modify env args
    env_args = kwargs["eval_environment_kwargs"]
    if args.horizon is not None:
        env_args["horizon"] = args.horizon
    env_args["render_camera"] = args.camera
    env_args["hard_reset"] = True
    env_args["ignore_done"] = True

    # Specify camera name if we're recording a video
    # if args.record_video:
    #     env_args["camera_names"] = args.camera
    #     env_args["camera_heights"] = 512
    #     env_args["camera_widths"] = 512

    # Setup video recorder if necessary
    if args.record_video:
        # Grab name of this rollout combo (read "controller" before it is
        # popped from env_args below)
        video_name = "{}-{}-{}".format(
            env_args["env_name"], "".join(env_args["robots"]), env_args["controller"]).replace("_", "-")
        obs_video_name = video_name + '-obs'
        # Calculate appropriate fps
        fps = int(env_args["control_freq"])
        # Define video writers; these rebind the module-level names that the
        # SIGINT handler reads.
        video_writer = imageio.get_writer("{}.mp4".format(video_name), fps=fps)
        video_writer_obs = imageio.get_writer("{}.mp4".format(obs_video_name), fps=fps)

    # Pop the controller: it is passed to suite.make as a loaded config
    # rather than as a raw env kwarg.
    controller = env_args.pop("controller")
    if controller in ALL_CONTROLLERS:
        controller_config = load_controller_config(default_controller=controller)
    else:
        # Anything that is not a known controller name is treated as a path
        # to a custom controller config file.
        controller_config = load_controller_config(custom_fpath=controller)

    # Create env
    env_suite = suite.make(
        **env_args,
        controller_configs=controller_config,
        has_renderer=not args.record_video,
        # has_offscreen_renderer=args.record_video,
        # use_object_obs=True,
        # use_camera_obs=args.record_video,
        reward_shaping=True,
    )

    # Make sure we only pass in the proprio and object obs (no images)
    # keys = ["object-state"]
    # for idx in range(len(env_suite.robots)):
    #     keys.append(f"robot{idx}_proprio-state")

    # Wrap environment so it's compatible with Gym API
    print("env_suite: ", type(env_suite), env_suite.camera_names)
    env = GymWrapper(env_suite)

    # Run rollout
    try:
        simulate_policy(
            env=env,
            model_path=os.path.join(args.load_dir, "params.pkl"),
            horizon=env_args["horizon"],
            render=not args.record_video,
            video_writer=video_writer,
            video_writer_obs=video_writer_obs,
            num_episodes=args.num_episodes,
            printout=True,
            use_gpu=args.gpu,
        )
    finally:
        # Always release the writers so the mp4 files are finalized, whether
        # the rollout finished normally or raised.
        # NOTE(review): relies on imageio's close() being a no-op on an
        # already-closed writer, in case simulate_policy or the SIGINT handler
        # closed it first -- confirm against the installed imageio version.
        for writer in (video_writer, video_writer_obs):
            if writer is not None:
                writer.close()