# Research Code for "ArCHer: Training Language Model Agents via Hierarchical Multi-Turn RL"
Yifei Zhou, Andrea Zanette, Jiayi Pan, Sergey Levine, Aviral Kumar
This repo supports the following online methods (an offline ArCHer implementation can also be found in this repo):
- ArCHer
- Online CHAI
- Online Filtered BC
and the following environments:
- Twenty Questions
- Guess My City
- Detective Game
- WebShop

To set up the codebase, run:
```bash
conda create -n archer python==3.10
conda activate archer
git clone https://github.com/YifeiZhou02/ArCHer
cd ArCHer
python -m pip install -e .
python3 -m spacy download en_core_web_sm
```
Offline datasets and SFT checkpoints used in the paper can be found here. Change the `huggingface_token` and `wandb_token` in `scripts/config/default.yaml`.
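For reference, these fields might look like the following (a hypothetical excerpt of `scripts/config/default.yaml`; the values are placeholders for your own credentials):

```yaml
# Hypothetical excerpt of scripts/config/default.yaml -- replace the
# placeholder values with your own credentials.
huggingface_token: hf_xxxxxxxxxxxxxxxx   # optional; needed for gated models such as Llama-2
wandb_token: xxxxxxxxxxxxxxxxxxxxxxxx    # your Weights and Biases API key
```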
Guess My City, Twenty Questions, and Detective Game are directly usable by changing `env_load_path` (the data to use for each environment), `checkpoint_path` (the SFT checkpoint to start from, as provided), and `save_path` (required; the path to save checkpoints and the replay buffer) in the corresponding configurations in `scripts/config`, such as `scripts/config/archer_20q.yaml`. For WebShop, additional installation is required on top of modifying the paths in the corresponding configuration.
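For example, the three path fields in `scripts/config/archer_20q.yaml` would be edited along these lines (a hypothetical excerpt; the paths are placeholders for your setup):

```yaml
# Hypothetical excerpt of scripts/config/archer_20q.yaml -- only the
# three path fields discussed above; all paths are placeholders.
env_load_path: /path/to/twenty_questions_data   # data for the environment
checkpoint_path: /path/to/sft_checkpoint        # provided SFT checkpoint
save_path: /path/to/save_dir                    # required: checkpoints and replay buffer
```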
You can directly run experiments with the following commands:
```bash
cd scripts
python run.py --config-name archer_20q
```
Different environments and methods can be run with the corresponding configurations.
To use the WebShop environment, the following additional setup is needed (skip this step if you do not plan to use WebShop). Go to WebShop's GitHub and follow the instructions to install the WebShop environment:
```bash
git clone https://github.com/princeton-nlp/webshop.git webshop
cd webshop
./setup.sh -d all
```
It turns out the provided installation guide is already outdated, so the following modifications are needed:
```bash
pip install Werkzeug==2.2.2
pip install pydantic==1.10.11
pip install --force-reinstall typing-extensions==4.5.0
pip install beautifulsoup4
conda install mkl=2021
python -m spacy download en_core_web_lg
```
By default, WebShop only loads 1,000 products for a faster environment preview. To load all products, change `web_agent_site/utils.py`:
```python
# DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2_1000.json')
# DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle_1000.json')
DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2.json')
DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle.json')
```
Then start the server at `127.0.0.1:3000`:

```bash
python -m web_agent_site.app --log --attrs
```
An additional step is required for running experiments on WebShop: the server must be running in the background. Note that `python -m` takes a module name rather than a path, so launch the server from inside the WebShop directory:

```bash
# launch the server in the background from the WebShop directory
(cd PATH_TO_WEBSHOP && python -m web_agent_site.app --log --attrs &)
# then, from the ArCHer repo:
cd scripts
python run.py --config-name archer_webshop
```
## Distributed Data Parallel with Accelerate
Experiments on a single GPU can be slow (e.g., ArCHer on Twenty Questions can take a week), so this codebase supports Distributed Data Parallel.
First, set up the Accelerate config by editing `scripts/config/accelerate_config/default_config.yaml`. Then change the run command to:

```bash
cd scripts
accelerate launch --config_file accelerate_config/default_config.yaml run.py --config-name archer_20q
```
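If you have not set up Accelerate before, a single-node multi-GPU DDP config typically looks roughly like the following. This is a sketch assuming one machine with 4 GPUs; the canonical way to produce this file is `accelerate config`:

```yaml
# Sketch of scripts/config/accelerate_config/default_config.yaml for
# single-node multi-GPU DDP; assumes 4 GPUs -- adjust num_processes.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
machine_rank: 0
num_processes: 4        # one process per GPU
gpu_ids: all
mixed_precision: 'no'
use_cpu: false
```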
## Configurations Explained

The parameters in the config files are explained below:

- `cache_dir`: The cache directory for Hugging Face Transformers (for saving pre-trained model weights, etc.).
- `huggingface_token`: (Optional) Hugging Face token for logging in (to access gated models such as Llama-2).
- `wandb_key`: This repo uses Weights and Biases for logging; put your wandb key here.
- `policy_lm`: The model name (from Hugging Face) for the policy language model. The main results in the paper use `gpt2`.
- `critic_lm`: The model name (from Hugging Face) for the critic language model. The main results in the paper use `roberta-base`.
- `agent_type`: The algorithm to use; currently supports `archer`, `chai`, and `online_filteredbc`.
- `use_baseline`: Whether or not to train a separate model as a token-level baseline. Will be added soon.
- `use_lora`: Whether or not to use LoRA for the policy language model.
- `max_new_tokens`: Maximum number of tokens to generate at each turn from the policy language model.
- `save_freq`: How often (in iterations) to save all model and optimizer weights.
- `eval_freq`: How often (in iterations) to run deterministic evaluations.
- `capacity`: Number of utterance-level interaction tuples (s, a, r, s') that the replay buffer can hold.
- `rollout_size`: Number of trajectories to collect at each iteration.
- `eval_size`: Number of trajectories to evaluate on at each evaluation.
- `batch_size`: Training batch size (the same for both the actor and the critic).
- `iterations`: Total number of iterations.
- `epochs`: Number of critic gradient steps per iteration.
- `actor_epochs`: Number of actor gradient steps per iteration.
- `warmup_iter`: Number of warm-up iterations during which only the critic is updated (i.e., the actor is not updated).
- `grad_accum_steps`: Number of gradient accumulation steps. Note that gradients are not normalized with respect to the number of accumulation steps, so the effective learning rate is `learning_rate * grad_accum_steps` (see the sketch after this list).
- `do_sample`: Whether or not sampling is used when rolling out trajectories to collect data.
- `temperature`: The temperature for sampling from the policy language model.
- `critic_lr`: The (unnormalized) learning rate for the critic; see also `grad_accum_steps`.
- `lm_lr`: The (unnormalized) learning rate for the actor; see also `grad_accum_steps`.
- `gamma`: The discount factor.
- `tau`: Polyak constant for soft updates of the target network.
- `max_grad_norm`: Maximum gradient-norm clipping threshold, applied after gradient accumulation.
- `use_wandb`: Whether or not to use Weights and Biases.
- `checkpoint_path`: The path to the SFT checkpoint to start from, as provided above.
- `save_path`: (Required) The path to save the replay buffer and training checkpoints.
- `env_name`: Which environment to use; currently supports `twenty_questions`, `guess_my_city`, `adventure` (used for the Detective Game), and `webshop`.
- `env_load_path`: The path to the data for the environment (different for each environment; see the example configs).
- `project_name`: Weights and Biases project name.
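To see how these parameters fit together, here is a minimal hypothetical config excerpt; all values are illustrative placeholders, not the settings used in the paper:

```yaml
# Hypothetical excerpt only -- all values are placeholders, not the
# paper's settings.
env_name: twenty_questions
env_load_path: /path/to/twenty_questions_data
checkpoint_path: /path/to/sft_checkpoint
save_path: /path/to/save_dir    # required
policy_lm: gpt2
critic_lm: roberta-base
agent_type: archer
batch_size: 8
grad_accum_steps: 8   # gradients are summed, not averaged, across these steps
critic_lr: 1e-5       # effective critic lr = 1e-5 * 8 = 8e-5
lm_lr: 1e-5           # effective actor  lr = 1e-5 * 8 = 8e-5
gamma: 0.95
tau: 0.1
```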
## Running with Larger LLMs

Our default configuration runs with GPT-2, but the codebase also supports state-of-the-art LLMs such as Mistral-7B. If you have a machine with more memory, simply try:
```bash
cd scripts
python run.py --config-name archer_llm_20q
```
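The LLM config presumably swaps in a larger policy model; a hypothetical excerpt is below (the model name and LoRA setting are assumptions for illustration, check `scripts/config/archer_llm_20q.yaml` for the actual values):

```yaml
# Hypothetical excerpt -- see scripts/config/archer_llm_20q.yaml for the
# real values; the model name below is an assumption.
policy_lm: mistralai/Mistral-7B-v0.1
use_lora: true    # LoRA keeps the 7B policy cheap enough to fine-tune
```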
## Citation

If you find this repo useful, you can cite:

```bibtex
@misc{zhou2024archer,
    title={ArCHer: Training Language Model Agents via Hierarchical Multi-Turn RL},
    author={Yifei Zhou and Andrea Zanette and Jiayi Pan and Sergey Levine and Aviral Kumar},
    year={2024},
    eprint={2402.19446},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```