Skip to content

Commit

Permalink
Implement checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
CovERUshKA committed Jul 30, 2024
1 parent 871079c commit ea4084e
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 397 deletions.
2 changes: 1 addition & 1 deletion data/autoexec_server.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ loglevel 2
sv_score_folder "records"

# Max players on server
sv_max_clients 64
#sv_max_clients 64

# Max players with the same IP address
sv_max_clients_per_ip 4
Expand Down
17 changes: 9 additions & 8 deletions src/engine/server/NN/ModelManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ int64_t n_out = 7;
double stdrt = 2e-2;
double learning_rate = 1e-4; // Default: 1e-3

int64_t mini_batch_size = 8192; // 4096, 8192, 16384
int64_t mini_batch_size = 8192; // 4096, 8192, 16384, 32768
int64_t ppo_epochs = 3; // Default: 4
double dbeta = 1e-3; // Default: 1e-3
double clip_param = 0.2; // Default: 0.2
Expand Down Expand Up @@ -66,7 +66,7 @@ void generate_random_hyperparameters()
return;
}

ModelManager::ModelManager(){
ModelManager::ModelManager(size_t batch_size, size_t count_players){
printf("1 %i\n", torch::cuda::is_available());
//net_module.eval();
//torch::set_num_threads(4);
Expand All @@ -76,14 +76,15 @@ ModelManager::ModelManager(){
ac->normal(0., stdrt);
//ac->eval();
opt = std::make_shared<torch::optim::Adam>(ac->parameters(), learning_rate);
//torch::load(ac, "train\\up_lr\\models\\last_model.pt");
//torch::load(*opt, "train\\up_lr\\models\\last_optimizer.pt");
torch::load(ac, "train\\1722357361235\\models\\last_model.pt");
torch::load(*opt, "train\\1722357361235\\models\\last_optimizer.pt");
cout << "Learning rate: " << learning_rate << " Gamma: " << gamma << " Beta: " << dbeta << " clip_param: " << clip_param << " Epochs: " << ppo_epochs << " Mini batch size: " << mini_batch_size << endl;
//Sleep(7000);
ac->to(device);
printf("2\n");
//Sleep(7000);
// opt(ac->parameters(), 1e-3);
PPO::Initilize(batch_size, count_players);
}

std::vector<ModelOutput> ModelManager::Decide(std::vector<ModelInputInputs> &input_inputs, std::vector<ModelInputBlocks> input_blocks)
Expand All @@ -106,12 +107,12 @@ std::vector<ModelOutput> ModelManager::Decide(std::vector<ModelInputInputs> &inp
//printf("1.2\n");
one_hotted_blocks = one_hotted_blocks.to(torch::kF32);
//printf("1.3\n");
one_hotted_blocks = one_hotted_blocks.view({64, -1});
one_hotted_blocks = one_hotted_blocks.view({(long long)input_inputs.size(), -1});
//printf("1.4\n");
torch::Tensor state_forward = torch::cat({state_inputs, one_hotted_blocks}, 1);
//printf("2\n");
//states.push_back(state);
// printf("33\n");
//printf("33\n");
// Play.
//int64_t decide_time = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
//cout << state_forward.sizes() << endl;
Expand Down Expand Up @@ -441,7 +442,7 @@ void ModelManager::SaveReplays()
return;
}

void ModelManager::Update()
void ModelManager::Update(double& avg_training_loss)
{
// Update.
//printf("Updating the network.\n");
Expand Down Expand Up @@ -469,7 +470,7 @@ void ModelManager::Update()
//torch::Tensor t_advantages = t_returns - t_values.slice(0, 0, rewards.size());
//printf("3");
//printf("UPDATING111\n");
PPO::update(ac, opt, rewards.size(), ppo_epochs, mini_batch_size, dbeta, gamma, device, clip_param);
avg_training_loss = PPO::update(ac, opt, rewards.size(), ppo_epochs, mini_batch_size, dbeta, gamma, device, clip_param);
//printf("UPDATed\n");
//printf("4");
// c = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/engine/server/NN/ModelManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct ModelOutput

struct ModelManager
{
ModelManager();
ModelManager(size_t batch_size, size_t count_players);

ModelOutput Decide(ModelInputInputs &input);
std::vector<ModelOutput> Decide(std::vector<ModelInputInputs> &input, std::vector<ModelInputBlocks> blocks);
Expand All @@ -66,7 +66,7 @@ struct ModelManager
void Reward(float reward, bool done);
void SaveReplays();

void Update();
void Update(double &avg_training_loss);

void Save(std::string filename);

Expand Down
1 change: 1 addition & 0 deletions src/engine/server/NN/Models.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct ActorCriticImpl : public torch::nn::Module
// Forward pass.
auto forward(torch::Tensor x) -> std::tuple<torch::Tensor, torch::Tensor>
{
//torch::NoGradGuard no_grad;
// Actor.
//printf("1\n");
//mu_ = torch::relu(a_lin1_->forward(x));
Expand Down
Loading

0 comments on commit ea4084e

Please sign in to comment.