Skip to content

Commit

Permalink
Merge pull request #4 from polixir/dev
Browse files Browse the repository at this point in the history
Fix bugs:

1. data root assignment;
2. TensorBoard injection
  • Loading branch information
xionghuichen authored Jun 4, 2022
2 parents 2af72f6 + 7d2df4a commit b54d1e2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ PS:
2. An alternative way is to build your own NFS for your physical machines and point data_root to the NFS.

# TODO
- [ ] video visualization.
- [ ] support sftp-based sync.
- [ ] support custom data structure saving and loading.
- [ ] support video visualization.
- [ ] add comments and documents to the functions.
- [ ] add an auto integration script.
- [ ] download / upload experiment logs through timestamp.
Expand Down
13 changes: 6 additions & 7 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class ExperimentLoader(object):
def __init__(self):
self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
self.load_date = exp_manager.hyper_param.get('loaded_date', None)
self.root = getattr(exp_manager, 'root', None)
self.data_root = None
self.data_root = getattr(exp_manager, 'root', None)
pass

def config(self, task_name, record_date, root):
Expand All @@ -49,12 +48,12 @@ def is_valid_config(self):
logger.warn("meet invalid loader config when use it")
logger.warn("load_date", self.load_date)
logger.warn("task_name", self.task_name)
logger.warn("root", self.root)
logger.warn("root", self.data_root)
return False

def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
if self.is_valid_config:
load_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
target_hp = copy.deepcopy(exp_manager.hyper_param)
target_hp.update(load_tester.hyper_param)
if hp_to_overwrite is not None:
Expand All @@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
:return:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# load checkpoint
load_res = {}
if var_prefix is not None:
Expand All @@ -100,7 +99,7 @@ def fork_log_files(self):
if self.is_valid_config:
global exp_manager
assert isinstance(exp_manager, Tester)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# copy log file
exp_manager.log_file_copy(loaded_tester)
# copy attribute
Expand All @@ -109,4 +108,4 @@ def fork_log_files(self):
exp_manager.private_config = loaded_tester.private_config


exp_loader = experimental_loader = ExperimentLoader()
exp_loader = experimental_loader = ExperimentLoader()
2 changes: 1 addition & 1 deletion RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def configure(dir=None, format_strs=None, comm=None, framework='tensorflow'):
if format_strs is None:
format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
format_strs = filter(None, format_strs)
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
output_formats = [make_output_format(f, dir, log_suffix, framework) for f in format_strs]
warn_output_formats = make_output_format('warn', dir, log_suffix, framework)
backup_output_formats = make_output_format('backup', dir, log_suffix, framework)

Expand Down
2 changes: 1 addition & 1 deletion RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _init_logger(self):
self.writer = None
# logger configure
logger.info("store file %s" % self.pkl_file)
logger.configure(self.log_dir, self.private_config["LOG_USED"])
logger.configure(self.log_dir, self.private_config["LOG_USED"], framework=self.private_config["DL_FRAMEWORK"])
for fmt in logger.Logger.CURRENT.output_formats:
if isinstance(fmt, logger.TensorBoardOutputFormat):
self.writer = fmt.writer
Expand Down

0 comments on commit b54d1e2

Please sign in to comment.