Dev #14

Merged
merged 35 commits on Oct 4, 2022
35 commits
0d25abf
Merge pull request #19 from xionghuichen/dev
xionghuichen Jun 22, 2022
cf30b18
Update README.md
xionghuichen Jun 22, 2022
33e6aee
Dev (#20)
xionghuichen Jul 13, 2022
74c7712
fix: minor changes for version compatibility
xionghuichen Jul 14, 2022
e36a1bc
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 14, 2022
3bdbf3e
Dev (#21)
xionghuichen Jul 14, 2022
96ba639
fix: a bug of sorting in torch-version checkpoint loading
xionghuichen Jul 14, 2022
680c6be
Dev (#22)
xionghuichen Jul 14, 2022
44c2cff
refactor: robust multi-key plot implementation
xionghuichen Jul 21, 2022
ce25856
feat: supoort pretty plotter
xionghuichen Jul 23, 2022
efc4815
refactor(log plotter): record scores of the log plotter
xionghuichen Jul 23, 2022
3fca57c
fix(exp_loader): add parameter ckp_index
xionghuichen Jul 23, 2022
017efc2
refactor(rla_script): add start_server to start_pretty_plotter.py
xionghuichen Jul 23, 2022
7c0b7dc
update readme
xionghuichen Jul 23, 2022
4239bd6
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 23, 2022
29e4932
Dev (#23)
xionghuichen Jul 23, 2022
5a0c180
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 27, 2022
d3ff59d
rm unsolved merge
xionghuichen Jul 27, 2022
f67a10a
Dev (#24)
xionghuichen Jul 27, 2022
65d2859
feat: tf-v2 compatible
xionghuichen Jul 27, 2022
845f3ab
refactor: add timestep recorder. refactor on exp_loader
xionghuichen Aug 10, 2022
9f799c3
test: add test data
xionghuichen Aug 10, 2022
a277416
feat(plot): track the hyper-parameter from the exp_manager instead of…
xionghuichen Aug 10, 2022
e9b29a7
Dev (#25)
xionghuichen Aug 10, 2022
00d26d2
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 10, 2022
f71bdcd
test(plot): add user cases and documents
xionghuichen Aug 11, 2022
a5f88dd
test(plot): add user cases
xionghuichen Aug 11, 2022
23349cd
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Aug 11, 2022
26ee5d3
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 11, 2022
bf87b78
Dev (#26)
xionghuichen Aug 11, 2022
4efee65
Update README.md (#28)
xionghuichen Aug 11, 2022
6140b9f
simplify codes
xionghuichen Sep 12, 2022
e9fb6fd
refactor: more robust freq print implementation
xionghuichen Sep 14, 2022
bf3f65b
update readme
xionghuichen Sep 25, 2022
8fdc403
update readme
xionghuichen Oct 4, 2022
40 changes: 28 additions & 12 deletions README.md
@@ -67,7 +67,17 @@ Modifying:
Querying:
1. tensorboard: the recorded variables are added to tensorboard events and can be loaded via standard tensorboard tools.
![img.png](resource/tb-img.png)
2. easy_plot: We give some APIs to load and visualize the data in CSV files. The results will be something like this:
```python
from RLA.easy_plot.plot_func_v2 import plot_func
data_root='your_project'
task = 'sac_test'
regs = [
'2022/03/01/21-[12]*'
]
_ = plot_func(data_root=data_root, task_table_name=task,
              regs=regs, split_keys=['info', 'van_sac', 'alpha_max'], metrics=['perf/rewards'])
```
![](resource/sample-plot.png)
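As an aside, the `regs` entry above selects experiment index names by a glob-style date/index pattern. A minimal sketch of that style of filtering (hypothetical helper; assumes fnmatch-like semantics, which may differ from `plot_func`'s actual matcher):

```python
from fnmatch import fnmatch

def select_logs(log_names, patterns):
    # Keep every experiment index name that matches at least one pattern.
    return [name for name in log_names if any(fnmatch(name, p) for p in patterns)]

logs = [
    '2022/03/01/21-1 seed=1',
    '2022/03/01/21-2 seed=2',
    '2022/03/02/09-1 seed=1',
]
# '[12]' matches a single character '1' or '2'; '*' matches the rest.
print(select_logs(logs, ['2022/03/01/21-[12]*']))
# both 21-1 and 21-2 runs are kept; the 03/02 run is filtered out
```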


@@ -86,7 +96,7 @@ We also list the RL research projects using RLA as follows:
```bash
git clone https://github.com/xionghuichen/RLAssistant.git
cd RLAssistant
python setup.py install
pip install -e .
```


@@ -99,13 +109,19 @@ We build an example project for integrating RLA, which can be seen in ./example/
1. We define the property of the database in `rla_config.yaml`. You can construct your YAML file based on the template in ./example/simplest_code/rla_config.yaml.
2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this.
```python
from RLA.easy_log.tester import exp_manager
from RLA import exp_manager
kwargs = {'env_id': 'Hopper-v2', 'lr': 1e-3}
exp_manager.set_hyper_param(**kwargs) # kwargs are the hyper-parameters for your experiment
exp_manager.add_record_param(["env_id"]) # add parts of hyper-parameters to name the index of data items for better readability.
task_name = 'demo_task' # define your task
rla_data_root = '../' # the place to store the data items.
exp_manager.configure(task_name, private_config_path='../../../rla_config.yaml', data_root=rla_data_root)

def get_package_path():
return os.path.dirname(os.path.abspath(__file__))

rla_data_root = get_package_path() # the place to store the data items.

rla_config = os.path.join(get_package_path(), 'rla_config.yaml')
exp_manager.configure(task_table_name=task_name, rla_config=rla_config, data_root=rla_data_root)
exp_manager.log_files_gen() # initialize the data items.
exp_manager.print_args()
```
@@ -124,9 +140,9 @@ We build an example project for integrating RLA, which can be seen in ./example/

We record scalars by `RLA.easy_log.logger`:
```python
from RLA.easy_log import logger
from RLA import logger
import tensorflow as tf
from RLA.easy_log.time_step import time_step_holder
from RLA import time_step_holder

for i in range(1000):
# time-steps (iterations)
@@ -143,7 +159,7 @@ for i in range(1000):

We save checkpoints of neural networks by `exp_manager.save_checkpoint`.
```python
from RLA.easy_log.tester import exp_manager
from RLA import exp_manager
exp_manager.new_saver()

for i in range(1000):
@@ -157,7 +173,7 @@ Currently we can record complex-structure data based on tensorboard:
```python
# from tensorflow summary
import tensorflow as tf
from RLA.easy_log import logger
from RLA import logger
summary = tf.Summary()
logger.log_from_tf_summary(summary)
# from tensorboardX writer
@@ -169,7 +185,7 @@ We will develop APIs to record common-used complex-structure data in RLA.easy_lo
Now we give a MatplotlibRecorder tool to manage your figures generated by matplotlib:

```python
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder as mpr
from RLA import MatplotlibRecorder as mpr
def plot_func():
import matplotlib.pyplot as plt
plt.plot([1,1,1], [2,2,2])
@@ -191,7 +207,7 @@ Result visualization:
For example, launch tensorboard by `tensorboard --logdir ./example/simplest_code/log/demo_task/2022/03`. We can see results:
![img.png](resource/demo-tb-res.png)
2. Easy_plot toolkit: The intermediate scalar variables are saved in a CSV file in `${data_root}/log/${task_name}/${index_name}/progress.csv`.
We develop high-level APIs to load the CSV files from multiple experiments and group the lines by custom keys. We give an example to use easy_plot toolkit in example/plot_res.ipynb.
We develop high-level APIs to load the CSV files from multiple experiments and group the lines by custom keys. We give an example of using the easy_plot toolkit in https://github.com/xionghuichen/RLAssistant/blob/main/example/plot_res.ipynb and more use cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_plot.py.
The result will be something like this:
![img.png](resource/demo-easy-to-plot-res.png)
3. View data in "results" directory directly: other types of data are stored in `${data_root}/results/${task_name}/${index_name}`
@@ -217,7 +233,7 @@ We manage the items in the database via toolkits in rla_scripts. Currently, the
3. Send to remote [TODO]
4. Download from remote [TODO]

We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database.
We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database. We give several use cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_scripts.py.


## Distributed training & centralized logs
4 changes: 3 additions & 1 deletion RLA/__init__.py
@@ -1,3 +1,5 @@
from RLA.easy_log.tester import exp_manager
from RLA.easy_log import logger
from RLA.easy_plot.plot_func_v2 import plot_func
from RLA.easy_log.time_step import time_step_holder
from RLA.easy_plot.plot_func_v2 import plot_func
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder
6 changes: 3 additions & 3 deletions RLA/easy_log/exp_loader.py
@@ -2,7 +2,7 @@
from RLA.easy_log.tester import exp_manager, Tester
import copy
import argparse
from typing import Optional, OrderedDict, Union, Dict, Any
from typing import Optional
from RLA.const import DEFAULT_X_NAME
from pprint import pprint

@@ -88,10 +88,10 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
load_res = {}
if var_prefix is not None:
loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
_, load_res = loaded_tester.load_checkpoint(ckp_index)
else:
loaded_tester.new_saver(max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
_, load_res = loaded_tester.load_checkpoint(ckp_index)
hist_variables = {}
if variable_list is not None:
for v in variable_list:
16 changes: 11 additions & 5 deletions RLA/easy_log/logger.py
@@ -406,7 +406,7 @@ def timestep():
ma_dict = {}


def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None):
def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
if key not in ma_dict:
ma_dict[key] = deque(maxlen=record_len)
if ignore_nan:
@@ -415,7 +415,10 @@ def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[U
else:
ma_dict[key].append(val)
if len(ma_dict[key]) == record_len:
record_tabular(key, np.mean(ma_dict[key]), exclude)
record_tabular(key, np.mean(ma_dict[key]), exclude, freq)
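The windowed mean above only emits a value once the deque is full; sketched in isolation (hypothetical helper names; only the deque bookkeeping mirrors the change above):

```python
from collections import deque
import numpy as np

def ma_record(key, val, record_len, records):
    # Append val to a bounded window and return the window mean once full.
    if key not in records:
        records[key] = deque(maxlen=record_len)
    records[key].append(val)
    if len(records[key]) == record_len:
        return float(np.mean(records[key]))
    return None  # window not yet full; nothing is recorded

records = {}
for v in [1.0, 2.0, 3.0, 4.0]:
    print(ma_record('loss', v, 3, records))
# None, None, then the 3-step window fills and slides: 2.0, 3.0
```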


lst_print_dict = {}

def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
"""
@@ -426,8 +429,11 @@ def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Opt
:param key: (Any) save to log this key
:param val: (Any) save to log this value
"""
if freq is None or timestep() % freq == 0:
if key not in lst_print_dict:
lst_print_dict[key] = -np.inf
if freq is None or timestep() - lst_print_dict[key] >= freq:
get_current().logkv(key, val, exclude)
lst_print_dict[key] = timestep()
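The change above swaps the old `timestep() % freq == 0` gate for an elapsed-steps check against lst_print_dict, so a key is still logged when calls never land exactly on a multiple of freq. A standalone sketch of the gating (hypothetical helper, not the RLA API):

```python
def should_log(key, timestep, freq, last_logged):
    # Log when freq is disabled or at least `freq` steps passed since the last log of `key`.
    if key not in last_logged:
        last_logged[key] = float('-inf')
    if freq is None or timestep - last_logged[key] >= freq:
        last_logged[key] = timestep
        return True
    return False

state = {}
print(should_log('reward', 0, 10, state))   # True: first call always logs
print(should_log('reward', 7, 10, state))   # False: only 7 steps elapsed
print(should_log('reward', 12, 10, state))  # True: 12 steps elapsed, even though 12 % 10 != 0
```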


def log_from_tf_summary(summary):
@@ -463,12 +469,12 @@ def logkv_mean(key, val):
"""
get_current().logkv_mean(key, val)

def logkvs(d, exclude:Optional[Union[str, Tuple[str, ...]]]=None):
def logkvs(d, prefix:Optional[str]='', exclude:Optional[Union[str, Tuple[str, ...]]]=None):
"""
Log a dictionary of key-value pairs
"""
for (k, v) in d.items():
logkv(k, v, exclude)
logkv(prefix+k, v, exclude)
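The new prefix parameter simply namespaces each key before delegating to logkv; in isolation (hypothetical helper name):

```python
def prefix_kvs(d, prefix=''):
    # Namespace every key, as logkvs now does before delegating to logkv.
    return {prefix + k: v for k, v in d.items()}

print(prefix_kvs({'rewards': 1.5, 'length': 200}, prefix='perf/'))
# {'perf/rewards': 1.5, 'perf/length': 200}
```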


def log_key_value(keys, values, prefix_name=''):
66 changes: 55 additions & 11 deletions RLA/easy_log/tester.py
@@ -15,7 +15,6 @@
import os.path as osp
import pprint

import tensorboardX

from RLA.easy_log.time_step import time_step_holder
from RLA.easy_log import logger
@@ -134,7 +133,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
:param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True),
you should mark the master node (is_master_node=True) to collect logs of the slave nodes (is_master_node=False).
:type is_master_node: bool
: param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default.
:param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default.
"""
if isinstance(rla_config, str):
self.private_config = load_yaml(rla_config)
@@ -543,11 +542,31 @@ def update_fph(self, cum_epochs):
# self.last_record_fph_time = cur_time
logger.dump_tabular()

def time_record(self, name):
def time_record(self, name:str):
"""
[deprecated] see RLA.easy_log.time_used_recorder
record the consumed time of your code snippet. call this function to start a recorder.
"name" is identifier to distinguish different recorder and record different snippets at the same time.
call time_record_end to end a recorder.
:param name: identifier of your code snippet.
:type name: str
:return:
:rtype:
"""
assert name not in self._rc_start_time
self._rc_start_time[name] = time.time()

def time_record_end(self, name):
def time_record_end(self, name:str):
"""
[deprecated] see RLA.easy_log.time_used_recorder
record the consumed time of your code snippet. call this function to start a recorder.
"name" is identifier to distinguish different recorder and record different snippets at the same time.
call time_record_end to end a recorder.
:param name: identifier of your code snippet.
:type name: str
:return:
:rtype:
"""
end_time = time.time()
start_time = self._rc_start_time[name]
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
@@ -566,23 +585,46 @@ def new_saver(self, max_to_keep, var_prefix=None):
import tensorflow as tf
if var_prefix is None:
var_prefix = ''
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
logger.info("save variable :")
for v in var_list:
logger.info(v)
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
try:
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
logger.info("save variable :")
for v in var_list:
logger.info(v)
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir,
save_relative_paths=True)

except AttributeError as e:
self.max_to_keep = max_to_keep
# tf.compat.v1.disable_eager_execution()
# tf = tf.compat.v1
# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
elif self.dl_framework == FRAMEWORK.torch:
self.max_to_keep = max_to_keep
else:
raise NotImplementedError

def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None):
def save_checkpoint(self, model_dict: Optional[dict] = None, related_variable: Optional[dict] = None):
if self.dl_framework == FRAMEWORK.tensorflow:
import tensorflow as tf
iter = self.time_step_holder.get_time()
cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
logger.info("save checkpoint to ", cpt_name, iter)
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
try:
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
except AttributeError as e:
if model_dict is None:
logger.warn("call save_checkpoints without passing a model_dict")
return
if self.checkpoint_keep_list is None:
self.checkpoint_keep_list = []
iter = self.time_step_holder.get_time()
# tf.compat.v1.disable_eager_execution()
# tf = tf.compat.v1
# self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)

tf.train.Checkpoint(**model_dict).save(osp.join(self.checkpoint_dir, "checkpoint-{}".format(iter)))
self.checkpoint_keep_list.append(iter)
self.checkpoint_keep_list = self.checkpoint_keep_list[-1 * self.max_to_keep:]
elif self.dl_framework == FRAMEWORK.torch:
import torch
if self.checkpoint_keep_list is None:
@@ -602,6 +644,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
for k, v in related_variable.items():
self.add_custom_data(k, v, type(v), mode='replace')
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')
self.serialize_object_and_save()
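Both the TF2 fallback and the torch branch cap checkpoint_keep_list with a trailing slice; the retention rule on its own (hypothetical helper, not the RLA API):

```python
def trim_checkpoints(keep_list, new_iter, max_to_keep):
    # Append the newest checkpoint index and keep only the last max_to_keep entries.
    return (keep_list + [new_iter])[-max_to_keep:]

print(trim_checkpoints([100, 200, 300], 400, 3))  # [200, 300, 400]
```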

def load_checkpoint(self, ckp_index=None):
if self.dl_framework == FRAMEWORK.tensorflow:
@@ -613,6 +656,7 @@ def load_checkpoint(self, ckp_index=None):
ckpt_path = tf.train.latest_checkpoint(cpt_name)
else:
ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index)
logger.info("load ckpt_path {}".format(ckpt_path))
self.saver.restore(tf.get_default_session(), ckpt_path)
max_iter = ckpt_path.split('-')[-1]
return int(max_iter), None