
Commit

refactoring and removing unused imports and code
Demirrr committed Sep 23, 2024
1 parent 367fbca commit 88a8c65
Showing 1 changed file with 0 additions and 40 deletions.
dicee/trainer/torch_trainer_ddp.py (40 changes: 0 additions & 40 deletions)
@@ -1,20 +1,13 @@
import os
import torch
import time
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Iterable
from dicee.abstracts import AbstractTrainer
from dicee.static_funcs_training import efficient_zero_grad
from torch.utils.data import DataLoader
from tqdm import tqdm

torch.set_float32_matmul_precision('high')

# DDP with gradient accumulation https://gist.github.com/mcarilli/bf013d2d2f4b4dd21ade30c9b52d5e2e
def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def make_iterable_verbose(iterable_object, verbose, desc="Default", position=None, leave=True) -> Iterable:
    if verbose:
        return tqdm(iterable_object, desc=desc, position=position, leave=leave)
@@ -151,36 +144,10 @@ def _run_epoch(self, epoch: int) -> float:
        self.train_dataset_loader.sampler.set_epoch(epoch)
        epoch_loss = 0
        i = 0
        construct_mini_batch_time = None
        for i, z in enumerate(self.train_dataset_loader):
            source, targets = self.extract_input_outputs(z)
            start_time = time.time()
            # if construct_mini_batch_time:
            #     construct_mini_batch_time = start_time - construct_mini_batch_time
            batch_loss = self._run_batch(source, targets)
            epoch_loss += batch_loss
            """
            if True:  # self.local_rank == self.global_rank==0:
                if construct_mini_batch_time:
                    print(
                        f"Global:{self.global_rank}"
                        f" | Local:{self.local_rank}"
                        f" | Epoch:{epoch + 1}"
                        f" | Batch:{i + 1}"
                        f" | Loss:{batch_loss}"
                        f" | ForwardBackwardUpdate:{(time.time() - start_time):.2f}sec"
                        f" | BatchConst.:{construct_mini_batch_time:.2f}sec")
                else:
                    print(
                        f"Global:{self.global_rank}"
                        f" | Local:{self.local_rank}"
                        f" | Epoch:{epoch + 1}"
                        f" | Batch:{i + 1}"
                        f" | Loss:{batch_loss}"
                        f" | ForwardBackwardUpdate:{(time.time() - start_time):.2f}secs")
            """

            construct_mini_batch_time = time.time()
        return epoch_loss / (i + 1)

    def train(self):
@@ -206,13 +173,6 @@ def train(self):
                else:
                    tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
            avg_epoch_loss = epoch_loss / len(self.train_dataset_loader)
            """
            print(f"Global:{self.global_rank}"
                  f" | Local:{self.local_rank}"
                  f" | Epoch:{epoch + 1}"
                  f" | Loss:{epoch_loss:.8f}"
                  f" | Runtime:{(time.time() - start_time) / 60:.3f}mins")
            """

            if self.local_rank == self.global_rank == 0:
                self.model.module.loss_history.append(avg_epoch_loss)
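Note (not part of the diff): the first hunk shows a comment pointing to a DDP-with-gradient-accumulation gist. As background, the pattern it refers to can be sketched roughly as follows. This is an illustrative example under assumed names (model, optimizer, loss_fn, data_loader, accum_steps), not code from this repository; the key idea is that DDP's no_sync() context manager suppresses the gradient all-reduce on every micro-batch except the last one in each accumulation window, so ranks synchronise only once per optimizer step.

import contextlib
from torch.nn.parallel import DistributedDataParallel as DDP

def run_epoch_with_accumulation(model: DDP, optimizer, loss_fn, data_loader, accum_steps: int = 4):
    # Illustrative sketch of gradient accumulation under DDP; names are assumed, not dicee APIs.
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, (source, targets) in enumerate(data_loader):
        sync_now = (step + 1) % accum_steps == 0
        # Skip the all-reduce on non-boundary micro-batches; only the last
        # micro-batch of each accumulation window synchronises across ranks.
        ctx = contextlib.nullcontext() if sync_now else model.no_sync()
        with ctx:
            loss = loss_fn(model(source), targets) / accum_steps
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)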
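Similarly, the print_peak_memory helper shown in the first hunk is a thin wrapper around torch.cuda.max_memory_allocated. A self-contained sketch of the same idea is below; the name report_peak_memory and the reset call are illustrative additions, not part of the original file.

import torch

def report_peak_memory(prefix: str, device: int = 0) -> None:
    # Print the peak CUDA memory allocated on `device` since the last reset.
    if torch.cuda.is_available():
        peak_mb = torch.cuda.max_memory_allocated(device) / 1e6
        print(f"{prefix}: {peak_mb:.0f} MB")
        # Reset the counter so the next call measures only the phase that follows.
        torch.cuda.reset_peak_memory_stats(device)

It would typically be called on rank 0 right after a batch's forward/backward/update, e.g. report_peak_memory('After optimizer step', device=0).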
