diff --git a/RLA/easy_log/logger.py b/RLA/easy_log/logger.py index 6ac17c2..7f6ed4d 100644 --- a/RLA/easy_log/logger.py +++ b/RLA/easy_log/logger.py @@ -688,7 +688,7 @@ def configure(dir=None, format_strs=None, comm=None, framework='tensorflow'): if format_strs is None: format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') format_strs = filter(None, format_strs) - output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + output_formats = [make_output_format(f, dir, log_suffix, framework) for f in format_strs] warn_output_formats = make_output_format('warn', dir, log_suffix, framework) backup_output_formats = make_output_format('backup', dir, log_suffix, framework) diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 9fce214..e7ba846 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -187,7 +187,7 @@ def _init_logger(self): self.writer = None # logger configure logger.info("store file %s" % self.pkl_file) - logger.configure(self.log_dir, self.private_config["LOG_USED"]) + logger.configure(self.log_dir, self.private_config["LOG_USED"], framework=self.private_config["DL_FRAMEWORK"]) for fmt in logger.Logger.CURRENT.output_formats: if isinstance(fmt, logger.TensorBoardOutputFormat): self.writer = fmt.writer