Skip to content

Commit

Permalink
Fix rm bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry24k committed Dec 13, 2023
1 parent e8506d1 commit 19b6265
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

# from .utils.datasets import Datasets

__version__ = "1.0.2"
__version__ = "1.0.3"
4 changes: 2 additions & 2 deletions mair/defenses/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def progress_end(self):

def print(self, record_type, str, *args, **kargs):
if record_type is not None:
print(str)
print(str, *args, **kargs)
self.log.append(str)
self.save_log()

def print_only(self, record_type, str, *args, **kargs):
if record_type is not None:
print(str)
print(str, *args, **kargs)

def _add_progress_time(self, dict_record):
dict_record["s/it"] = np.array(self._progress_times).mean()
Expand Down
21 changes: 10 additions & 11 deletions mair/defenses/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,16 @@ def fit(
start_iter = 0

# Print train information
if record_type:
self.rm.print("[%s]" % self.__class__.__name__)
self.rm.print("Training Information.")
self.rm.print("-Epochs: %s" % n_epochs)
self.rm.print("-Optimizer: %s" % self.optimizer)
self.rm.print("-Scheduler: %s" % self.scheduler)
self.rm.print("-Minmizer: %s" % self.minimizer)
self.rm.print("-Save Path: %s" % save_path)
self.rm.print("-Save Type: %s" % str(save_type))
self.rm.print("-Record Type: %s" % str(record_type))
self.rm.print("-Device: %s" % self.device)
self.rm.print(record_type, "[%s]" % self.__class__.__name__)
self.rm.print(record_type, "Training Information.")
self.rm.print(record_type, "-Epochs: %s" % n_epochs)
self.rm.print(record_type, "-Optimizer: %s" % self.optimizer)
self.rm.print(record_type, "-Scheduler: %s" % self.scheduler)
self.rm.print(record_type, "-Minmizer: %s" % self.minimizer)
self.rm.print(record_type, "-Save Path: %s" % save_path)
self.rm.print(record_type, "-Save Type: %s" % str(save_type))
self.rm.print(record_type, "-Record Type: %s" % str(record_type))
self.rm.print(record_type, "-Device: %s" % self.device)

# Start training
for epoch in range(start_epoch, n_epochs):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name="mair",
version="1.0.2",
version="1.0.3",
description="MAIR is a PyTorch-based adversarial training framework.",
author="Harry Kim",
author_email="24k.harry@gmail.com",
Expand Down

0 comments on commit 19b6265

Please sign in to comment.