We propose SelfControl, a novel method that utilizes suffix gradients to control the behavior of large language models (LLMs) without explicit human annotations. Given a guideline expressed as a suffix string and the model's self-assessment of adherence, SelfControl computes the gradient of this self-judgment with respect to the model's hidden states, directly steering the auto-regressive generation process towards the desired behavior. To enhance efficiency, we introduce SelfControlPrefix, a compact module that encapsulates the learned representations from suffix gradients into a Prefix Controller, facilitating inference-time control over various LLM behaviors. Our experiments demonstrate SelfControl's efficacy across multiple domains, including emotional modulation, ensuring harmlessness, and enhancing complex reasoning. In particular, SelfControlPrefix enables plug-and-play control and joint control of multiple attributes, improving model outputs without altering model parameters or increasing inference-time costs.
- Self-Control of LLM Behaviors by Compressing Suffix Gradient into Prefix Controller
git clone git@github.com:HenryCai11/LLM-Control.git
cd LLM-Control
pip install -r requirements.txt
Framework of Iterative Control using Suffix Gradients: see framework_vedio.mp4.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from self_control.suffix_gradient import WrappedModel
from self_control.utils import SuffixItem

# load an instruction-tuned model and its tokenizer
# (Mistral-7B-Instruct-v0.2 is the model used throughout this README)
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# prepare wrapped model
wrapped_model = WrappedModel(model.eval(), tokenizer)

# prepare control: the suffix asks the model to judge its own emotion,
# and this self-judgment is what the suffix gradients are computed from
prompt = "You find that you are the winner of a contest"
user_tag = "[INST]"
assistant_tag = "[/INST]"
suffix = SuffixItem(suffix=f" {user_tag} Are you sad? Give answer as \"No, I am happy\" or \"Yes, I am sad\". {assistant_tag} Answer: ", target="Yes")

# placeholder loss function; with binary=True the suffix score (logit difference)
# is used, and e.g. cross-entropy can be used instead with binary=False
loss_fct = torch.nn.CrossEntropyLoss()

# start control
output_dict = wrapped_model.controlled_generate(
    prompt=prompt,
    suffix=suffix,
    loss_fct=loss_fct,
    top_k=-1,
    coeff=-0.5,
    iterations=3,
    max_search_steps=5,
    max_new_tokens=100,
    return_intermediate=True,
    search=True,
    binary=True,
    gradient_manipulation="clipping",
)
print(output_dict["final_responses"])
Argument | Recommended Value | Comment |
---|---|---|
suffix | - | You can easily define your own suffix using the SuffixItem class (see the example below this table). It is recommended to use instruction-tuned models, and make sure to include the user/assistant tags in the suffix. |
coeff | between -0.5 and 0 | The initial step size |
max_search_steps | >3 | Number of steps for searching the step size at each iteration |
top_k | -1 | k is the number of gradients to keep. Gradients are ranked by their norms; when k=-1, all of them are used. |
loss_fct | - | By default, suffix scores, i.e. the logit difference of a contrastive pair, are used to calculate gradients, but other objectives are supported. For example, you can set binary=False and use a cross-entropy loss. This is just a design choice and you can try out your own objectives! |
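For reference, suffixes for other attributes are defined in the same way. The snippet below is only an illustrative sketch (the wording of the self-judgment question and the target are up to you, and the direction of control is set jointly by the target and the sign of coeff), reusing the user/assistant tags from the example above:

# illustrative only: a custom suffix for a harmlessness self-judgment
harmless_suffix = SuffixItem(
    suffix=f" {user_tag} Is your response harmful? Give answer as \"Yes, it is harmful\" or \"No, it is harmless\". {assistant_tag} Answer: ",
    target="No",
)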
Framework and training pipeline of SelfControlPrefix:
To train a Prefix Controller, you can take the following steps:
You can generate seed queries for arbitrary attributes using the script we provide; you may need to adjust the prompts in self_control/suffix_gradient/prompts.py. Alternatively, you can simply use existing datasets as seed queries.
python -m self_control.suffix_gradient.generate_seed_queries --attribute avalon
Next, you need to generate target embeddings. The command below generates the embeddings and stores them in a pkl file, which serves as the dataset for training a Prefix Controller. Note that you need to generate two datasets, i.e. a training set and a validation set; a companion command for the validation split is sketched after the example below.
CUDA_VISIBLE_DEVICES=0 python -m self_control.suffix_gradient.generate_delta_ds \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--attribute reasoning \
--output_name reasoning-llama \
--start_from_idx 0 \
--max_num_data 400 \
--epoch 1 \
--max_new_tokens 256 \
--search \
--batchsize 1 \
--init_coeff -0.1 \
--n_branches 6 \
--iteration 2 \
--return_hiddens \
--add_prefix \
--max_norm 0.5
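The trainer below expects both a training set and a validation set (here reasoning-llama and reasoning-llama-eval). As a sketch, the validation split can be generated by re-running the same command with a different output name and a disjoint slice of the seed queries; the index range here is only an example and should be adjusted to your data:

CUDA_VISIBLE_DEVICES=0 python -m self_control.suffix_gradient.generate_delta_ds \
    --model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
    --attribute reasoning \
    --output_name reasoning-llama-eval \
    --start_from_idx 400 \
    --max_num_data 100 \
    --epoch 1 \
    --max_new_tokens 256 \
    --search \
    --batchsize 1 \
    --init_coeff -0.1 \
    --n_branches 6 \
    --iteration 2 \
    --return_hiddens \
    --add_prefix \
    --max_norm 0.5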
At this point, you are ready to train a Prefix Controller! You can use the following command:
CUDA_VISIBLE_DEVICES=0 python -m self_control.prefix_control.prefix_trainer \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--training_set_name reasoning-llama \
--eval_set_name reasoning-llama-eval \
--attribute reasoning \
--batchsize 16 \
--lr 3e-3 \
--accumulation_steps 8 \
--peft_type "prefix+adapter" \
--norm_threshold 0.5 \
--pick_by_eval
We open-source some of the Prefix Controller checkpoints on Hugging Face:
Name | Comment |
---|---|
HenryCai1129/selfcontrol-prefix-reasoning-mistral | Improving reasoning ability |
HenryCai1129/selfcontrol-prefix-calm2surprised-mistral | Control from calm to surprised |
An example notebook of loading and compositing Prefix Controllers is available here.
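The notebook is the authoritative reference for loading and compositing Prefix Controllers. Purely as a rough sketch, if a released checkpoint can be attached as a standard PEFT adapter (the training above uses a prefix+adapter PEFT type), loading it might look roughly like this; the exact procedure in the notebook may differ:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# load the base model the controller was trained for
base = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(base, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base)

# assumption: the checkpoint can be attached as a PEFT adapter;
# see the official notebook for the exact loading and compositing code
model = PeftModel.from_pretrained(model, "HenryCai1129/selfcontrol-prefix-reasoning-mistral")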
We offer GPT-based evaluation protocols for emotions and HH-dialogue. By default, it is recommended to set the API key in the environment variable OPENAI_API_KEY. You can also hard-code it by modifying the corresponding code.
You can use the following command to evaluate emotions (and other self-defined attributes, by modifying the prompts):
python -m self_control.utils.test_results \
--attribute angry \
--threshold 2.5 \
--file_path angry2peaceful-final.jsonl \
--suffix_score_direction 'negative' \
--model "output-name" \
--report_score
You can also use the following command to calculate the win rate against the original responses:
python -m self_control.utils.test_win_rate \
--attribute rlhf \
--model 'output-name' \
--orig_path 'path-to-orig-response' \
--target_path 'path-to-target-response'
We also use the Perspective API for toxicity, and scripts from cot-decoding for GSM8K. For the Perspective API, please set the key in the environment variable PERSPECTIVE_API_KEY.
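The repo's scripts handle the Perspective API calls for you. For context only, a minimal standalone sketch of scoring toxicity with the Perspective API (via the google-api-python-client package) looks like this, using the PERSPECTIVE_API_KEY variable mentioned above:

import os
from googleapiclient import discovery

# build a Perspective API client with the key from the environment
client = discovery.build(
    "commentanalyzer",
    "v1alpha1",
    developerKey=os.environ["PERSPECTIVE_API_KEY"],
    discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
    static_discovery=False,
)

# request a TOXICITY score for a single response
body = {
    "comment": {"text": "some model response to score"},
    "requestedAttributes": {"TOXICITY": {}},
}
response = client.comments().analyze(body=body).execute()
print(response["attributeScores"]["TOXICITY"]["summaryScore"]["value"])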
In addition, you can use test_results to draw ROC curves once you have obtained the results, using the command below:
python -m self_control.utils.test_results \
--attribute angry \
--threshold 2.5 \
--file_path angry2peaceful-final.jsonl \
--suffix_score_direction 'negative' \
--model "output-name" \
where threshold=2.5 means the decision boundary is 2.5.
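For intuition, the curve is a standard ROC computed over per-example scores and binary labels; a minimal sketch with scikit-learn (placeholder data, not this repo's file format) is:

from sklearn.metrics import roc_curve, auc

# placeholder data: per-example scores and binary labels (1 = attribute present)
scores = [0.2, 0.8, 0.4, 0.9]
labels = [0, 1, 0, 1]

fpr, tpr, thresholds = roc_curve(labels, scores)
print("AUC:", auc(fpr, tpr))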
Here's an example ROC curve for toxicity:
We demonstrate in our paper that SelfControl can also be used to generate preference pairs for Direct Preference Optimization (DPO). For DPO training, we use code from the alignment-handbook; interested readers are encouraged to refer to their repo for more information. For training data and responses from the DPO-tuned models, please refer to data. More interestingly, we can run controlled_generate on the new responses by feeding them to the initialization_prompt argument. The experiment regarding this can be found here.
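As a sketch of what the preference pairs mentioned above look like, each example pairs the original response (rejected) with the controlled response (chosen) in the prompt/chosen/rejected format that alignment-handbook DPO recipes consume; the file names and field names below are hypothetical, not this repo's exact schema:

import json

# hypothetical file names: one JSONL of original responses and one of
# SelfControl-steered responses, aligned line by line
with open("orig_responses.jsonl") as f:
    orig = [json.loads(line) for line in f]
with open("controlled_responses.jsonl") as f:
    controlled = [json.loads(line) for line in f]

# standard DPO preference format: prompt / chosen / rejected
pairs = [
    {
        "prompt": o["prompt"],
        "chosen": c["response"],    # steered response as the preferred one
        "rejected": o["response"],  # original response as the dispreferred one
    }
    for o, c in zip(orig, controlled)
]

with open("dpo_pairs.jsonl", "w") as f:
    for p in pairs:
        f.write(json.dumps(p) + "\n")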
Another interesting yet under-explored part of our paper is the exploratory experiments (analysis of suffix gradients). You can play with them in Analysis/Analysis.ipynb, or try them out in our colab demo.
Here are some examples:
The WrappedModel class is borrowed from RepE. Thanks for their great work!
- Write up a simple document containing all the details for further study based on SelfControl
@misc{cai2024selfcontrol,
title={Self-Control of LLM Behaviors by Compressing Suffix Gradient into Prefix Controller},
author={Min Cai and Yuchen Zhang and Shichang Zhang and Fan Yin and Difan Zou and Yisong Yue and Ziniu Hu},
year={2024},
eprint={2406.02721},
archivePrefix={arXiv},
primaryClass={cs.CL}
}