main_es.py
import os
from functools import partial
from math import log

import numpy as np
import torch

from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr.model import init_ppo, PolicyWithInstinct
from arguments import get_args
from env_util import register_set_goal
from exp_dir_util import get_experiment_save_dir, get_start_gen_idx
from simpleES import EvolutionStrategy
from train_test_model import inner_loop_ppo

def get_model_weights(model: PolicyWithInstinct):
    """Return a detached NumPy copy of every evolvable parameter tensor."""
    params = model.get_evolvable_params()
    copy_params = []
    for p in params:
        copy_params.append(p.data.clone().detach().numpy())
    return copy_params

# Fitness function: evaluate one ES candidate on every goal environment.
def es_fitness_funct(parameters, env_list, args, num_steps, num_updates):
    # All entries except the last are policy weights; the last entry holds a
    # learning-rate sample that is currently ignored.
    weights = parameters[:-1]
    # Fix the learning rate to what is defined in the arguments
    learning_rate = args.lr  # parameters[-1][0]
    learning_rate = abs(learning_rate)
    # Run one PPO inner loop per goal environment and collect its fitness.
    goal_info = [
        inner_loop_ppo(
            weights, args, learning_rate, num_steps, num_updates,
            run_idx=num_att, input_envs=env_list[num_att]
        )
        for num_att in range(len(env_list))
    ]
    goal_fitnesses, _, _ = zip(*goal_info)
    return sum(goal_fitnesses)
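
# A single ES candidate is a list of NumPy arrays: every evolvable policy
# tensor followed by a one-element array holding a learning-rate sample
# (overridden by args.lr above). A sketch of how a worker would score one
# candidate, assuming env_list is a list of pre-built goal environments:
#
#     fitness = es_fitness_funct(
#         candidate, env_list, args=args, num_steps=1500, num_updates=1
#     )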

if __name__ == "__main__":
    pop_size = 504       # ES population size: candidates evaluated per generation
    num_steps = 1500     # environment steps per PPO inner-loop rollout
    args = get_args()

    # Set up the parallelization: use an MPI worker pool when available,
    # otherwise fall back to serial evaluation in this process.
    try:
        from mpipool import Pool

        pool = Pool()
    except Exception:
        pool = None

    experiment_save_dir = get_experiment_save_dir(args)
    env_name = register_set_goal(0)
    init_sigma = args.init_sigma

    # A single vectorized environment on the CPU, used to build the blueprint policy.
    envs = make_vec_envs(
        env_name, args.seed, 1, args.gamma, None, torch.device("cpu"), False
    )

    if args.load_ga:
        # Resume from the last saved generation of a previous run.
        last_iter = get_start_gen_idx(args.load_ga, experiment_save_dir) - 1
        start_weights = torch.load(
            os.path.join(experiment_save_dir, f"saved_weights_gen_{last_iter}.dat")
        )
    else:
        # Start from a freshly initialized PPO policy and append the
        # learning rate as the final evolvable parameter.
        blueprint_model = init_ppo(envs, log(init_sigma))
        start_weights = get_model_weights(blueprint_model)
        start_weights.append(np.array([args.lr]))

    # fitness_function = make_es_fitness_funct(args, num_steps, 1, args.num_goal_samples)
    fitness_function = partial(
        es_fitness_funct, args=args, num_steps=num_steps, num_updates=1
    )

    es = EvolutionStrategy(
        start_weights,
        fitness_function,
        args,
        population_size=pop_size,
        sigma=0.1,
        learning_rate=0.1,
        decay=0.995,
        experiment_save_dir=experiment_save_dir,
    )
    # Run for up to 1000 generations, resuming from the saved generation
    # index when args.load_ga is set.
    es.run(
        1000,
        pool=pool,
        print_step=1,
        start_iteration=get_start_gen_idx(args.load_ga, experiment_save_dir),
    )
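
# Example launch (a sketch; the exact flag names depend on arguments.py and on
# how many MPI ranks you give mpipool):
#
#     mpiexec -n 64 python main_es.py
#
# When no MPI pool can be created, candidates are evaluated serially with
# pool=None.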