-
Notifications
You must be signed in to change notification settings - Fork 1
/
Envs.cpp
50 lines (49 loc) · 1.56 KB
/
Envs.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include "Envs.h"

#include <utility>
/**
 * Resets every environment and gathers the initial batch result.
 *
 * Each of the NUM_ENVS environments has reset() called on it; the value of
 * its get_the_1st_observation() is stored in the matching feature slot, and
 * a reward of 0 plus a done flag of false are appended for it.
 *
 * @return Res holding one feature / reward / done entry per environment.
 */
Envs::Res Envs::reset() {
    Res batch;
    for (int env_id = 0; env_id < NUM_ENVS; ++env_id) {
        auto &env = _envs[env_id];
        env.reset();
        // The observation is stored exactly as returned, with no
        // element-wise conversion.
        // NOTE(review): feature.at() presumes Res::feature already has
        // NUM_ENVS slots after default construction — confirm in Envs.h.
        batch.feature.at(env_id) = env.get_the_1st_observation();
        batch.reward.emplace_back(0);
        batch.done.emplace_back(false);
    }
    return batch;
}
/**
 * Resets a single environment and returns its initial observation.
 *
 * @param env_idx Index of the environment to reset.
 * @return The value of get_the_1st_observation() for that environment,
 *         read after its reset() has run.
 * @throws Whatever _envs.at() raises for an invalid index (std::out_of_range
 *         for the standard containers).
 */
vector<vector<double>> Envs::reset(int env_idx) {
    // env_idx is supplied by the caller, so use checked access: an invalid
    // (or negative) index now fails loudly instead of being undefined
    // behaviour through operator[].
    auto &env = _envs.at(env_idx);
    env.reset();
    return env.get_the_1st_observation();
}
/**
 * Advances every environment by one step and gathers the batch result.
 *
 * Environment i is stepped with actions.at(i); its feature is stored in
 * feature slot i, and its reward and done flag are appended in order.
 *
 * @param actions One action vector per environment; must contain at least
 *                NUM_ENVS entries.
 * @return Res holding one feature / reward / done entry per environment.
 * @throws std::out_of_range if actions has fewer than NUM_ENVS entries
 *         (raised by actions.at()).
 */
Envs::Res Envs::step(const std::vector<std::vector<double>> &actions) {
    Res res;
    for (int i = 0; i < NUM_ENVS; i++) {
        // Scoped to the iteration: no throw-away default construction, and
        // the per-step result is initialised directly from the call.
        Drcu::Res drcu_res = _envs[i].step(actions.at(i));
        // drcu_res is discarded at the end of this iteration, so hand its
        // feature buffer over instead of deep-copying it on every step.
        res.feature.at(i) = std::move(drcu_res.feature);
        res.reward.emplace_back(drcu_res.reward);
        res.done.emplace_back(drcu_res.done);
    }
    return res;
}
/**
 * Initialises every environment from a C-style argument list and gathers
 * the initial batch result.
 *
 * The argc C-strings in short_format_argv are copied into a
 * std::vector<std::string>, which is passed to init() of each of the
 * NUM_ENVS environments. For every environment the value of
 * get_the_1st_observation() is stored in the matching feature slot, and a
 * reward of 0 plus a done flag of false are appended.
 *
 * @param argc              Number of entries in short_format_argv.
 * @param short_format_argv Array of argc NUL-terminated strings.
 * @return Res holding one feature / reward / done entry per environment.
 */
Envs::Res Envs::init(int argc, char **short_format_argv) {
    Res batch;

    // Convert the raw argument array into owned strings, one per entry.
    std::vector<std::string> args(argc);
    char **next_arg = short_format_argv;
    for (std::string &arg : args) {
        arg = *next_arg++;
    }

    for (int env_id = 0; env_id < NUM_ENVS; ++env_id) {
        auto &env = _envs[env_id];
        env.init(args);
        // The observation is stored exactly as returned, with no
        // element-wise conversion.
        batch.feature.at(env_id) = env.get_the_1st_observation();
        batch.reward.emplace_back(0);
        batch.done.emplace_back(false);
    }
    return batch;
}
/**
 * Forwards to get_all_vio() of the first environment (index 0).
 *
 * Only environment 0 is consulted; the other environments are not queried.
 *
 * @return The four-element array returned by that environment.
 */
std::array<double, 4> Envs::get_all_vio() const {
    const auto &first_env = _envs[0];
    return first_env.get_all_vio();
}