Skip to content

Commit

Permalink
feat(easy_log,easy_plot): add parameter.json to LOG folder. support t…
Browse files Browse the repository at this point in the history
…he hyper-parmater loading (for easy_plot) from LOG folder.
  • Loading branch information
xionghuichen committed Feb 7, 2023
1 parent d64e64d commit 4f13598
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 13 deletions.
3 changes: 2 additions & 1 deletion RLA/auto_ftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import traceback
from RLA.const import *
from RLA.easy_log import logger

import pysftp


Expand All @@ -17,6 +16,7 @@ def ftp_factory(name, server, username, password, port, ignore=None):
else:
raise NotImplementedError


class FTPHandler(object):

def __init__(self, ftp_server, username, password, port, ignore=None):
Expand Down Expand Up @@ -139,6 +139,7 @@ def close(self):
self.ftp.quit()
self.ftp.close()


class SFTPHandler(FTPHandler):

def __init__(self, sftp_server, username, password, port, ignore=None):
Expand Down
2 changes: 1 addition & 1 deletion RLA/easy_log/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
OTHER_RESULTS = 'results'
ARCHIVED_TABLE = 'arc'
default_log_types = [LOG, CODE, CHECKPOINT, ARCHIVE_TESTER, OTHER_RESULTS]

HYPARAM = 'parameter'

class LoadTesterMode:
FORK_TO_NEW = 'fork'
Expand Down
8 changes: 5 additions & 3 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ def log_files_gen(self):
self.serialize_object_and_save()
self.__copy_source_code(self.run_file, code_dir)
self._feed_hyper_params_to_tb()
params = self.hyper_param
for param_dir in [self.code_dir, self.log_dir]:
with open(osp.join(param_dir, HYPARAM + '.json'), 'w') as f:
json.dump(params, f, sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
print("gen:", osp.join(param_dir, 'parameter.json'))
self.print_log_dir()

def update_log_files_location(self, root:str):
Expand Down Expand Up @@ -782,9 +787,6 @@ def print_args(self):
# formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date)
params = exp_manager.hyper_param
# params['formatted_log_name'] = formatted_log_name
json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'),
sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
print("gen:", osp.join(self.code_dir, 'parameter.json'))


def print_large_memory_variable(self):
Expand Down
13 changes: 7 additions & 6 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Created by xionghuichen at 2022/8/10
# Email: chenxh@lamda.nju.edu.cn
import glob
import json
import os.path as osp
import os
import dill
import copy
import numpy as np
from typing import Dict, List, Tuple, Type, Union, Optional, Callable
import matplotlib.pyplot as plt

from RLA import logger
from RLA.const import DEFAULT_X_NAME
from RLA.query_tool import experiment_data_query, extract_valid_index

from RLA.easy_plot import plot_util
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS, HYPARAM



Expand All @@ -25,7 +24,6 @@ def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
else:
return task_split_key


def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
use_buf=False, verbose=True,
x_bound: Optional[int]=None,
Expand Down Expand Up @@ -97,7 +95,11 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
if verbose:
print("find log", v.dirname)
counter += 1
result.hyper_param = tester_dict[k].exp_manager.hyper_param
if os.path.exists(osp.join(v.dirname, HYPARAM + '.json')):
with open(osp.join(v.dirname, HYPARAM + '.json')) as f:
result.hyper_param = json.load(f)
else:
result.hyper_param = tester_dict[k].exp_manager.hyper_param
results.append(result)
reg_group[reg].append(result)
print("find log number", counter)
Expand Down Expand Up @@ -126,7 +128,6 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs)
print("--- complete process ---")
if save_name is not None:
import os
file_name = osp.join(data_root, OTHER_RESULTS, 'easy_plot', save_name)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
if lgd is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"env_id": "Test-v1",
"info": "default exp info",
"input_size": 16,
"learning_rate": 0.001,
"loaded_date": true,
"loaded_task_name": "",
"seed": 88
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"input_size": 16,
"learning_rate": 0.0001
}
2 changes: 1 addition & 1 deletion test/test_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 16,
"id": "a55937ed",
"metadata": {
"scrolled": false
Expand Down
14 changes: 13 additions & 1 deletion test/test_proj/proj/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_log_tf(self):

exp_manager.new_saver(var_prefix='', max_to_keep=1)
# synthetic target function.
for i in range(0, 1000):
for i in range(0, 100):
exp_manager.time_step_holder.set_time(i)
x_input = np.random.normal(0, 3, [64, kwargs["input_size"]])
y = target_func(x_input)
Expand Down Expand Up @@ -143,10 +143,22 @@ def test_sent_to_master(self):
yaml = self._load_rla_config()
try:
from test.test_proj.proj import private_config
# try to import libs
except ImportError as e:
print("[WARN] for this test, you should config your username, password, and the remote root firstly.")
return
# raise RuntimeError
try:
if private_config.protocol == 'ftp':
import ftplib
elif private_config.protocol == 'sftp':
import pysftp
else:
raise NotImplementedError
except ImportError as e:
print(e)
print(f"[WARN] the select protocol {private_config.protocol} cannot be loaded. skip the unittest.")
return
yaml['DL_FRAMEWORK'] = 'torch'
yaml['SEND_LOG_FILE'] = True
yaml['REMOTE_SETTING']['ftp_server'] = '127.0.0.1'
Expand Down

0 comments on commit 4f13598

Please sign in to comment.