diff --git a/README.md b/README.md index a0761ef..9150f6b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,17 @@ Modifying: Querying: 1. tensorboard: the recorded variables are added to tensorboard events and can be loaded via standard tensorboard tools. ![img.png](resource/tb-img.png) -2. easy_plot: We give some APIs to load and visualize the data in CSV files. The results will be something like this: +2. easy_plot: We give some APIs to load and visualize the data in CSV files. The results will be something like this: + ```python + from RLA.easy_plot.plot_func_v2 import plot_func + data_root='your_project' + task = 'sac_test' + regs = [ + '2022/03/01/21-[12]*' + ] + _ = plot_func(data_root=data_root, task_table_name=task, + regs=regs , split_keys=['info', 'van_sac', 'alpha_max'], metrics=['perf/rewards']) + ``` ![](resource/sample-plot.png) @@ -86,7 +96,7 @@ We also list the RL research projects using RLA as follows: ```angular2html git clone https://github.com/xionghuichen/RLAssistant.git cd RLAssistant -python setup.py install +pip install -e . ``` @@ -99,13 +109,19 @@ We build an example project for integrating RLA, which can be seen in ./example/ 1. We define the property of the database in `rla_config.yaml`. You can construct your YAML file based on the template in ./example/simplest_code/rla_config.yaml. 2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this. ```python - from RLA.easy_log.tester import exp_manager + from RLA import exp_manager kwargs = {'env_id': 'Hopper-v2', 'lr': 1e-3} exp_manager.set_hyper_param(**kwargs) # kwargs are the hyper-parameters for your experiment exp_manager.add_record_param(["env_id"]) # add parts of hyper-parameters to name the index of data items for better readability. task_name = 'demo_task' # define your task - rla_data_root = '../' # the place to store the data items. 
- exp_manager.configure(task_name, private_config_path='../../../rla_config.yaml', data_root=rla_data_root) + + def get_package_path(): + return os.path.dirname(os.path.abspath(__file__)) + + rla_data_root = get_package_path() # the place to store the data items. + + rla_config = os.path.join(get_package_path(), 'rla_config.yaml') + exp_manager.configure(task_table_name=task_name, rla_config=rla_config, data_root=rla_data_root) exp_manager.log_files_gen() # initialize the data items. exp_manager.print_args() ``` @@ -124,9 +140,9 @@ We build an example project for integrating RLA, which can be seen in ./example/ We record scalars by `RLA.easy_log.logger`: ```python -from RLA.easy_log import logger +from RLA import logger import tensorflow as tf -from RLA.easy_log.time_step import time_step_holder +from RLA import time_step_holder for i in range(1000): # time-steps (iterations) @@ -143,7 +159,7 @@ for i in range(1000): We save checkpoints of neural networks by `exp_manager.save_checkpoint`. ```python -from RLA.easy_log.tester import exp_manager +from RLA import exp_manager exp_manager.new_saver() for i in range(1000): @@ -157,7 +173,7 @@ Currently we can record complex-structure data based on tensorboard: ```python # from tensorflow summary import tensorflow as tf -from RLA.easy_log import logger +from RLA import logger summary = tf.Summary() logger.log_from_tf_summary(summary) # from tensorboardX writer @@ -169,7 +185,7 @@ We will develop APIs to record common-used complex-structure data in RLA.easy_lo Now we give a MatplotlibRecorder tool to manage your figures generated by matplotlib: ```python -from RLA.easy_log.complex_data_recorder import MatplotlibRecorder as mpr +from RLA import MatplotlibRecorder as mpr def plot_func(): import matplotlib.pyplot as plt plt.plot([1,1,1], [2,2,2]) @@ -191,7 +207,7 @@ Result visualization: For example, lanuch tensorboard by `tensorboard --logdir ./example/simplest_code/log/demo_task/2022/03`. 
We can see results: ![img.png](resource/demo-tb-res.png) 2. Easy_plot toolkit: The intermediate scalar variables are saved in a CSV file in `${data_root}/log/${task_name}/${index_name}/progress.csv`. - We develop high-level APIs to load the CSV files from multiple experiments and group the lines by custom keys. We give an example to use easy_plot toolkit in example/plot_res.ipynb. + We develop high-level APIs to load the CSV files from multiple experiments and group the lines by custom keys. We give an example to use easy_plot toolkit in https://github.com/xionghuichen/RLAssistant/blob/main/example/plot_res.ipynb and more use cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_plot.py. The result will be something like this: ![img.png](resource/demo-easy-to-plot-res.png) 3. View data in "results" directory directly: other type of data are stored in `${data_root}/results/${task_name}/${index_name}` @@ -217,7 +233,7 @@ We manage the items in the database via toolkits in rla_scripts. Currently, the 3. Send to remote [TODO] 4. Download from remote [TODO] -We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database. +We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database. 
We give several use cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_scripts.py ## Distributed training & centralized logs diff --git a/RLA/__init__.py b/RLA/__init__.py index 796741c..0a4478b 100644 --- a/RLA/__init__.py +++ b/RLA/__init__.py @@ -1,3 +1,5 @@ from RLA.easy_log.tester import exp_manager from RLA.easy_log import logger -from RLA.easy_plot.plot_func_v2 import plot_func \ No newline at end of file +from RLA.easy_log.time_step import time_step_holder +from RLA.easy_plot.plot_func_v2 import plot_func +from RLA.easy_log.complex_data_recorder import MatplotlibRecorder \ No newline at end of file diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index b9a82b7..6d98c0c 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -2,7 +2,7 @@ from RLA.easy_log.tester import exp_manager, Tester import copy import argparse -from typing import Optional, OrderedDict, Union, Dict, Any +from typing import Optional from RLA.const import DEFAULT_X_NAME from pprint import pprint @@ -88,10 +88,10 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: load_res = {} if var_prefix is not None: loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1) - _, load_res = loaded_tester.load_checkpoint() + _, load_res = loaded_tester.load_checkpoint(ckp_index) else: loaded_tester.new_saver(max_to_keep=1) - _, load_res = loaded_tester.load_checkpoint() + _, load_res = loaded_tester.load_checkpoint(ckp_index) hist_variables = {} if variable_list is not None: for v in variable_list: diff --git a/RLA/easy_log/logger.py b/RLA/easy_log/logger.py index 8bcccb2..d43c0a2 100644 --- a/RLA/easy_log/logger.py +++ b/RLA/easy_log/logger.py @@ -406,7 +406,7 @@ def timestep(): ma_dict = {} -def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None): +def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, 
...]]]=None, freq:Optional[int]=None): if key not in ma_dict: ma_dict[key] = deque(maxlen=record_len) if ignore_nan: @@ -415,7 +415,10 @@ def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[U else: ma_dict[key].append(val) if len(ma_dict[key]) == record_len: - record_tabular(key, np.mean(ma_dict[key]), exclude) + record_tabular(key, np.mean(ma_dict[key]), exclude, freq) + + +lst_print_dict = {} def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None): """ @@ -426,8 +429,11 @@ def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Opt :param key: (Any) save to log this key :param val: (Any) save to log this value """ - if freq is None or timestep() % freq == 0: + if key not in lst_print_dict: + lst_print_dict[key] = -np.inf + if freq is None or timestep() - lst_print_dict[key] >= freq: get_current().logkv(key, val, exclude) + lst_print_dict[key] = timestep() def log_from_tf_summary(summary): @@ -463,12 +469,12 @@ def logkv_mean(key, val): """ get_current().logkv_mean(key, val) -def logkvs(d, exclude:Optional[Union[str, Tuple[str, ...]]]=None): +def logkvs(d, prefix:Optional[str]='', exclude:Optional[Union[str, Tuple[str, ...]]]=None): """ Log a dictionary of key-value pairs """ for (k, v) in d.items(): - logkv(k, v, exclude) + logkv(prefix+k, v, exclude) def log_key_value(keys, values, prefix_name=''): diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index d98e69d..cc2dc1c 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -15,7 +15,6 @@ import os.path as osp import pprint -import tensorboardX from RLA.easy_log.time_step import time_step_holder from RLA.easy_log import logger @@ -134,7 +133,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo :param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True), you should mark the master node (is_master_node=True) 
to collect logs of the slave nodes (is_master_node=False). :type is_master_node: bool - : param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default. + :param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default. """ if isinstance(rla_config, str): self.private_config = load_yaml(rla_config) @@ -543,11 +542,31 @@ def update_fph(self, cum_epochs): # self.last_record_fph_time = cur_time logger.dump_tabular() - def time_record(self, name): + def time_record(self, name:str): + """ + [deprecated] see RLA.easy_log.time_used_recorder + record the consumed time of your code snippet. call this function to start a recorder. + "name" is an identifier to distinguish different recorders and record different snippets at the same time. + call time_record_end to end a recorder. + :param name: identifier of your code snippet. + :type name: str + :return: + :rtype: + """ assert name not in self._rc_start_time self._rc_start_time[name] = time.time() - def time_record_end(self, name): + def time_record_end(self, name:str): + """ + [deprecated] see RLA.easy_log.time_used_recorder + record the consumed time of your code snippet. call this function to end a recorder and log the time used. + "name" is an identifier to distinguish different recorders and record different snippets at the same time. + call time_record with the same name first to start the recorder. + :param name: identifier of your code snippet. 
+ :type name: str + :return: + :rtype: + """ end_time = time.time() start_time = self._rc_start_time[name] logger.record_tabular("time_used/{}".format(name), end_time - start_time) @@ -566,23 +585,46 @@ def new_saver(self, max_to_keep, var_prefix=None): import tensorflow as tf if var_prefix is None: var_prefix = '' - var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix) - logger.info("save variable :") - for v in var_list: - logger.info(v) - self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True) + try: + var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix) + logger.info("save variable :") + for v in var_list: + logger.info(v) + self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, + save_relative_paths=True) + + except AttributeError as e: + self.max_to_keep = max_to_keep + # tf.compat.v1.disable_eager_execution() + # tf = tf.compat.v1 + # var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix) elif self.dl_framework == FRAMEWORK.torch: self.max_to_keep = max_to_keep else: raise NotImplementedError - def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None): + def save_checkpoint(self, model_dict: Optional[dict] = None, related_variable: Optional[dict] = None): if self.dl_framework == FRAMEWORK.tensorflow: import tensorflow as tf iter = self.time_step_holder.get_time() cpt_name = osp.join(self.checkpoint_dir, 'checkpoint') logger.info("save checkpoint to ", cpt_name, iter) - self.saver.save(tf.get_default_session(), cpt_name, global_step=iter) + try: + self.saver.save(tf.get_default_session(), cpt_name, global_step=iter) + except AttributeError as e: + if model_dict is None: + logger.warn("call save_checkpoints without passing a model_dict") + return + if self.checkpoint_keep_list is None: + self.checkpoint_keep_list = [] + iter = 
self.time_step_holder.get_time() + # tf.compat.v1.disable_eager_execution() + # tf = tf.compat.v1 + # self.saver.save(tf.get_default_session(), cpt_name, global_step=iter) + + tf.train.Checkpoint(**model_dict).save(tester.checkpoint_dir + "checkpoint-{}".format(iter)) + self.checkpoint_keep_list.append(iter) + self.checkpoint_keep_list = self.checkpoint_keep_list[-1 * self.max_to_keep:] elif self.dl_framework == FRAMEWORK.torch: import torch if self.checkpoint_keep_list is None: @@ -602,6 +644,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt for k, v in related_variable.items(): self.add_custom_data(k, v, type(v), mode='replace') self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace') + self.serialize_object_and_save() def load_checkpoint(self, ckp_index=None): if self.dl_framework == FRAMEWORK.tensorflow: @@ -613,6 +656,7 @@ def load_checkpoint(self, ckp_index=None): ckpt_path = tf.train.latest_checkpoint(cpt_name) else: ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index) + logger.info("load ckpt_path {}".format(ckpt_path)) self.saver.restore(tf.get_default_session(), ckpt_path) max_iter = ckpt_path.split('-')[-1] return int(max_iter), None diff --git a/RLA/easy_plot/plot_func_v2.py b/RLA/easy_plot/plot_func_v2.py index 1acb282..5e18185 100644 --- a/RLA/easy_plot/plot_func_v2.py +++ b/RLA/easy_plot/plot_func_v2.py @@ -11,14 +11,15 @@ from RLA import logger from RLA.const import DEFAULT_X_NAME -from RLA.query_tool import experiment_data_query +from RLA.query_tool import experiment_data_query, extract_valid_index + from RLA.easy_plot import plot_util from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS -def default_key_to_legend(parse_list, y_name): - task_split_key = '.'.join(parse_list) +def default_key_to_legend(parse_dict, split_keys, y_name): + task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys) return task_split_key + ' eval:' + y_name @@ -26,7 +27,7 @@ 
def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me use_buf=False, verbose=True, xlim: Optional[tuple] = None, xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[str] = None, - scale_dict: Optional[dict] = None, replace_legend_keys: Optional[list] = None, + scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None, key_to_legend_fn: Optional[Callable] = default_key_to_legend, save_name: Optional[str] = None, *args, **kwargs): """ @@ -34,7 +35,7 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me The function is to load your experiments and plot curves. You can group several experiments into a single figure through this function. It is completed by loading experiments satisfying [data_root, task_table_name, regs] pattern, - grouping by "split_keys" or by the "regs" terms (see replace_legend_keys), and plotting the customized "metrics". + grouping by "split_keys" or by the "regs" terms (see regs2legends), and plotting the customized "metrics". The function support several configure to customize the figure, including xlim, xlabel, ylabel, key_to_legend_fn, etc. The function also supports several configure to post-process your log data, including resample, smooth_step, scale_dict, key_to_legend_fn, etc. @@ -61,7 +62,13 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me :param scale_dict: a function dict, to map the value of the metrics through customize functions. e.g.,set metrics = ['return'], scale_dict = {'return': lambda x: np.log(x)}, then we will plot a log-scale return. :type scale_dict: Optional[dict] - :param args: set the label of the y axes. + :param regs2legends: use regex-to-legend mode to plot the figure. Each item in regs will be grouped into a curve. + In this reg2legend_map mode, you should define the legend name for each curve. See test/test_plot/test_reg_map_mode for details. 
+ :type regs2legends: Optional[list] = None + :param key_to_legend_fn: we give a default function to stringify the k-v pairs. you can customize your own function in key_to_legend_fn. + See default_key_to_legend for the default way and test/test_plot/test_customize_legend_name_mode for details. + :type key_to_legend_fn: Optional[Callable] = default_key_to_legend + :param args/kwargs: send other parameters to plot_util.plot_results :return: :rtype: @@ -98,17 +105,17 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me if ylabel is None: ylabel = metrics - if replace_legend_keys is not None: - assert len(replace_legend_keys) == len(regs) and len(metrics) == 1, \ + if regs2legends is not None: + assert len(regs2legends) == len(regs) and len(metrics) == 1, \ "In manual legend-key mode, the number of keys should be one-to-one matched with regs" - # if len(replace_legend_keys) == len(regs): + # if len(regs2legends) == len(regs): group_fn = lambda r: split_by_reg(taskpath=r, reg_group=reg_group, y_names=y_names) else: group_fn = lambda r: picture_split(taskpath=r, split_keys=split_keys, y_names=y_names, key_to_legend_fn=key_to_legend_fn) _, _, lgd, texts, g2lf, score_results = \ plot_util.plot_results(results, xy_fn= lambda r, y_names: csv_to_xy(r, DEFAULT_X_NAME, y_names, final_scale_dict), - group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, replace_legend_keys=replace_legend_keys, *args, **kwargs) + group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, regs2legends=regs2legends, *args, **kwargs) print("--- complete process ---") if save_name is not None: import os @@ -127,9 +134,10 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me def split_by_reg(taskpath, reg_group, y_names): task_split_key = "None" for i , reg_k in enumerate(reg_group.keys()): - if taskpath.dirname in reg_group[reg_k]: - assert task_split_key == "None", "one experiment should belong to only one 
reg_group" - task_split_key = str(i) + for result in reg_group[reg_k]: + if taskpath.dirname == result.dirname: + assert task_split_key == "None", "one experiment should belong to only one reg_group" + task_split_key = str(i) assert len(y_names) == 1 return task_split_key, y_names @@ -137,15 +145,17 @@ def split_by_reg(taskpath, reg_group, y_names): def split_by_task(taskpath, split_keys, y_names, key_to_legend_fn): pair_delimiter = '&' kv_delimiter = '=' - parse_list = [] + parse_dict = {} for split_key in split_keys: if split_key in taskpath.hyper_param: - parse_list.append(split_key + '=' + str(taskpath.hyper_param[split_key])) + parse_dict[split_key] = str(taskpath.hyper_param[split_key]) + # parse_list.append(split_key + '=' + str(taskpath.hyper_param[split_key])) else: - parse_list.append(split_key + '=NF') + parse_dict[split_key] = 'NF' + # parse_list.append(split_key + '=NF') param_keys = [] for y_name in y_names: - param_keys.append(key_to_legend_fn(parse_list, y_name)) + param_keys.append(key_to_legend_fn(parse_dict, split_keys, y_name)) return param_keys, y_names diff --git a/RLA/easy_plot/plot_util.py b/RLA/easy_plot/plot_util.py index 42779aa..d0535e1 100644 --- a/RLA/easy_plot/plot_util.py +++ b/RLA/easy_plot/plot_util.py @@ -305,7 +305,7 @@ def plot_results( ylabel=None, title=None, replace_legend_keys=None, - replace_legend_sort=None, + regs2legends=None, pretty=False, bound_line=None, colors=None, @@ -505,6 +505,8 @@ def allequal(qs): legend_lines = legend_lines[sorted_index] if replace_legend_keys is not None: legend_keys = np.array(replace_legend_keys) + if regs2legends is not None: + legend_keys = np.array(regs2legends) # if replace_legend_sort is not None: # sorted_index = replace_legend_sort # else: @@ -521,17 +523,17 @@ def allequal(qs): if shaded_err: res = g2lf[original_legend_keys[index] + '-se'] res[0].update(props={"color": colors[index % len(colors)]}) - print("{}-err : ({:.2f} \pm {:.2f})".format(legend_keys[index], res[1][-1], 
res[2][-1])) + print("{}-err : ({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1])) score_results[legend_keys[index]+'-err'] = [res[1][-1], res[2][-1]] if shaded_std: res = g2lf[original_legend_keys[index] + '-ss'] res[0].update(props={"color": colors[index % len(colors)]}) - print("{}-std :({:.2f} \pm {:.2f})".format(legend_keys[index], res[1][-1], res[2][-1])) + print("{}-std :({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1])) score_results[legend_keys[index]+'-std'] = [res[1][-1], res[2][-1]] if shaded_range: res = g2lf[original_legend_keys[index] + '-sr'] res[0].update(props={"color": colors[index % len(colors)]}) - print("{}-range : ({:.2f}, {:.2f})".format(legend_keys[index], res[1][-1], res[2][-1])) + print("{}-range : ({:.3f}, {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1])) score_results[legend_keys[index]+'-range'] = [res[1][-1], res[2][-1]] if bound_line is not None: diff --git a/example/sb3_ppo_example/ppo/main.py b/example/sb3_ppo_example/ppo/main.py index 5bc6f91..3cd7146 100644 --- a/example/sb3_ppo_example/ppo/main.py +++ b/example/sb3_ppo_example/ppo/main.py @@ -22,7 +22,7 @@ def mujoco_arg_parser(): task_name = 'demo_task' exp_manager.set_hyper_param(**vars(args)) exp_manager.add_record_param(["info", "seed", 'env']) -exp_manager.configure(task_name, private_config_path='../rla_config.yaml', data_root='../') +exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root='../') exp_manager.log_files_gen() exp_manager.print_args() diff --git a/example/sb_ppo_example/ppo2/run_mujoco.py b/example/sb_ppo_example/ppo2/run_mujoco.py index f0fcf0a..1c0090b 100644 --- a/example/sb_ppo_example/ppo2/run_mujoco.py +++ b/example/sb_ppo_example/ppo2/run_mujoco.py @@ -64,7 +64,7 @@ def main(): task_name = 'demo_task' exp_manager.set_hyper_param(**vars(args)) exp_manager.add_record_param(["info", "seed", 'env']) - exp_manager.configure(task_name, private_config_path='../rla_config.yaml', 
data_root='../') + exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root='../') exp_manager.log_files_gen() exp_manager.print_args() # [RLA] optional: diff --git a/example/simplest_code/project/main.py b/example/simplest_code/project/main.py index f9934dd..c0fad91 100644 --- a/example/simplest_code/project/main.py +++ b/example/simplest_code/project/main.py @@ -28,7 +28,7 @@ def get_param(): task_name = 'demo_task' rla_data_root = '../' -exp_manager.configure(task_name, private_config_path='../rla_config.yaml', data_root=rla_data_root) +exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root) exp_manager.log_files_gen() exp_manager.print_args() diff --git a/test/test_plot.py b/test/test_plot.py index bfc67a7..e895b7f 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -1,26 +1,85 @@ # Created by xionghuichen at 2022/8/10 # Email: chenxh@lamda.nju.edu.cn from test._base import BaseTest +import numpy as np from RLA.easy_log.log_tools import DeleteLogTool, Filter from RLA.easy_log.log_tools import ArchiveLogTool, ViewLogTool from RLA.easy_log.tester import exp_manager - +from RLA import plot_func import os class ScriptTest(BaseTest): - def test_plot(self): - from RLA import plot_func + + def get_basic_info(self): data_root = 'test_data_root' task = 'demo_task' + return data_root, task + + def test_plot_basic(self): + data_root, task = self.get_basic_info() + regs = [ '2022/03/01/21-[12]*' ] _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], metrics=['perf/mse']) + # customize the figure _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], metrics=['perf/mse'], ylim=(0, 0.1)) _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio', ) + + + def test_pretty_plot(self): + data_root, task = 
self.get_basic_info() + + regs = [ + '2022/03/01/21-[12]*' + ] + # save image + _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], + metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio', + shaded_range=False, show_number=False, pretty=True) _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio', + shaded_range=False, pretty=True, save_name='saved_image.png') + + def test_reg_map_mode(self): + # reg-map mode. + data_root, task = self.get_basic_info() + regs = [ + '2022/03/01/21-[12]*learning_rate=0.01*', + '2022/03/01/21-[12]*learning_rate=0.00*', + ] + _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], + metrics=['perf/mse'], regs2legends=['lr=0.01', 'lr<=0.001'], shaded_range=False, pretty=True) + + def test_customize_legend_name_mode(self): + data_root, task = self.get_basic_info() + regs = [ + '2022/03/01/21-[12]*' + ] + + def my_key_to_legend(parse_dict, split_keys, y_name): + + task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys) + task_split_key = task_split_key.replace('learning_rate', 'α') + return task_split_key + + _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], + metrics=['perf/mse'], + key_to_legend_fn=my_key_to_legend, + shaded_range=False, pretty=True, show_number=False) + + def test_post_process(self): + data_root, task = self.get_basic_info() + regs = [ + '2022/03/01/21-[12]*' + ] + + _ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'], + metrics=['perf/mse'], + scale_dict={'perf/mse': lambda x: np.log(x)}, + ylabel='RMSE', + shaded_range=False, pretty=True, show_number=False) diff --git a/test/test_proj/proj/test_manager.py b/test/test_proj/proj/test_manager.py index 8c07721..a156f53 100644 --- 
a/test/test_proj/proj/test_manager.py +++ b/test/test_proj/proj/test_manager.py @@ -24,7 +24,7 @@ def _init_proj(self, config_yaml, **kwargs): task_name = 'test_manger_demo_task' rla_data_root = os.path.join(DATABASE_ROOT, 'test_data_root') config_yaml['BACKUP_CONFIG']['backup_code_dir'] = ['proj'] - exp_manager.configure(task_name, private_config_path=config_yaml, data_root=rla_data_root, + exp_manager.configure(task_name, rla_config=config_yaml, data_root=rla_data_root, code_root=CODE_ROOT, **kwargs) exp_manager.log_files_gen() exp_manager.print_args() @@ -66,6 +66,8 @@ def test_log_tf(self): if i % 20 == 0: exp_manager.save_checkpoint() if i % 10 == 0: + logger.ma_record_tabular("perf/mse-long", np.mean(mse_loss.detach().cpu().numpy()), 10, freq=25) + logger.record_tabular("y_out-long", np.mean(y), freq=25) def plot_func(): import matplotlib.pyplot as plt testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1) @@ -109,6 +111,8 @@ def test_log_torch(self): logger.ma_record_tabular("perf/mse", np.mean(mse_loss.detach().cpu().numpy()), 10) logger.record_tabular("y_out", np.mean(y)) if i % 10 == 0: + logger.ma_record_tabular("perf/mse-long", np.mean(mse_loss.detach().cpu().numpy()), 10, freq=25) + logger.record_tabular("y_out-long", np.mean(y), freq=25) def plot_func(): import matplotlib.pyplot as plt testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1)