update rlhf experience generation
sunzeyeah committed May 19, 2023
1 parent 5fc2128 commit fa5457d
Showing 2 changed files with 134 additions and 68 deletions.
183 changes: 118 additions & 65 deletions src/models/trainer.py
@@ -1279,13 +1279,15 @@ def __init__(self, rlhf_engine, args):
self.gamma = 1.0
self.lam = 0.95

def _generate_sequence(self, inputs):
# max_min_length = self.max_answer_seq_len + prompts.shape[1]
def generate_sequence(self, inputs):
self.eval()
print_gpu_utilization("generate_sequence - before model.generate", self.args.local_rank)
print_gpu_utilization_torch("generate_sequence - before model.generate", self.args.local_rank)
batch_size = inputs['input_ids'].shape[0]
prompt_length = inputs['input_ids'].shape[-1]

with torch.no_grad():
logger.debug(f"[_generate_sequence] inputs: {inputs}")
logger.debug(f"[generate_sequence] inputs: {inputs}")
prompts = []
answers = []
outputs = dict()
@@ -1313,7 +1315,7 @@ def _generate_sequence(self, inputs):
num_return_sequences=self.args.num_return_sequences,
top_p=self.args.top_p,
temperature=self.args.temperature)
logger.debug(f"[_generate_sequence] seq: {seq}")
logger.debug(f"[generate_sequence] seq: {seq}")
for output_ids in seq:
answer = self.tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)
prompts.append(prompt)
@@ -1362,17 +1364,18 @@ def _generate_sequence(self, inputs):
if "pangu" in self.args.actor_model_path:
outputs = self.tokenizer(prompts, max_length=self.args.max_length,
padding="max_length", return_tensors="pt", return_token_type_ids=False)
logger.debug(f"[_generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
logger.debug(f"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
elif "chatglm" in self.args.actor_model_path:
outputs = self.tokenizer(prompts, answers, max_length=self.args.max_length,
padding="max_length", return_tensors="pt")
logger.debug(f"[_generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
logger.debug(f"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
elif "glm" in self.args.actor_model_path:
outputs = {key: torch.stack(val) for key, val in outputs.items()}
logger.debug(f"[_generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
logger.debug(f"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}")
else:
raise ValueError(f"Unsupported model name: {self.args.actor_model_path}")

print_gpu_utilization("generate_sequence - after model.generate", self.args.local_rank)
print_gpu_utilization_torch("generate_sequence - after model.generate", self.args.local_rank)
# Filter out seq with no answers (or very short ones). This happens when users directly use the pre-training ckpt without supervised finetuning
# NOTE: this will cause each GPU to have a different number of examples
# ans = seq[:, prompt_length:]
@@ -1386,25 +1389,22 @@ def _generate_sequence(self, inputs):
# else:
# out_seq.append(seq[i:i + 1])
# out_seq = torch.cat(out_seq, dim=0) # concat output in the batch dim
# logger.debug(f"[_generate_sequence] out_seq: {out_seq}")
# logger.debug(f"[generate_sequence] out_seq: {out_seq}")

return outputs
return outputs, prompt_length

def generate_experience(self, inputs):
print_gpu_utilization("generate_experience - before self._generate_experience", self.args.local_rank)
print_gpu_utilization_torch("generate_experience - before self._generate_experience", self.args.local_rank)
def generate_experience(self, output_sequences, answer_start_indices, device):
self.eval()
outputs = self._generate_sequence(inputs)
self.train()
print_gpu_utilization("generate_experience - after self._generate_experience", self.args.local_rank)
print_gpu_utilization_torch("generate_experience - after self._generate_experience", self.args.local_rank)
print_gpu_utilization("generate_experience - before call actor and critic", self.args.local_rank)
print_gpu_utilization_torch("generate_experience - before call actor and critic", self.args.local_rank)

# pad_token_id = self.tokenizer.pad_token_id
# attention_mask = seq.not_equal(pad_token_id).long()
device = inputs['input_ids'].device
input_ids = outputs['input_ids'].to(device)
attention_mask = outputs['attention_mask'].to(device) if "attention_mask" in outputs else None
position_ids = outputs['position_ids'].to(device) if "position_ids" in outputs else None
input_ids = output_sequences['input_ids'].to(device)
attention_mask = output_sequences['attention_mask'].to(device) if "attention_mask" in output_sequences else None
position_ids = output_sequences['position_ids'].to(device) if "position_ids" in output_sequences else None
print_gpu_utilization("generate_experience - after setting output_sequences device", self.args.local_rank)
print_gpu_utilization_torch("generate_experience - after setting output_sequences device", self.args.local_rank)

with torch.no_grad():
output = self.actor_model(input_ids, attention_mask=attention_mask, position_ids=position_ids)
@@ -1424,7 +1424,8 @@ def generate_experience(self, inputs):
logits_ref = output_ref.logits

return {
'prompts': inputs['input_ids'],
# 'prompts': inputs['input_ids'],
'answer_start_indices': answer_start_indices,
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
@@ -1434,76 +1435,100 @@ def generate_experience(self, inputs):
'rewards': reward_score
}

def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, action_mask):
logger.debug(f"[compute_rewards] prompts: {prompts.shape}, log_probs: {log_probs.shape}, ref_log_probs: {ref_log_probs.shape}, "
def compute_rewards(self, starts, log_probs, ref_log_probs, reward_score, action_mask):
'''
:param starts: list of answer start indices (one per example)
:param log_probs: shape=batch_size * (max_length-1)
:param ref_log_probs: shape=batch_size * (max_length-1)
:param reward_score: shape=batch_size
:param action_mask: shape=batch_size * (max_length-1), with prompt positions set to 0
:return: KL-penalized per-token rewards, shape=batch_size * (max_length-1)
'''
logger.debug(f"[compute_rewards] log_probs: {log_probs.shape}, ref_log_probs: {ref_log_probs.shape}, "
f"reward_score: {reward_score.shape}, action_mask: {action_mask.shape}")
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
rewards = kl_divergence_estimate
logger.debug(f"before rewards: {rewards.shape}")
start = prompts.shape[1] - 1
ends = start + action_mask.sum(1)
# start = prompts.shape[1] - 1
# ends = start + action_mask.sum(1)
sums = action_mask.sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
self.clip_reward_value)
batch_size = log_probs.shape[0]
for j in range(batch_size):
logger.debug(f"j={j}, ends[j]={ends[j]}, rewards[j, start:ends[j]]: {rewards[j, start:ends[j]].shape}")
rewards[j, start:ends[j]][-1] += reward_clip[j]
logger.debug(f"j={j}, sums[j]={sums[j]}, rewards[j, start:ends[j]]: {rewards[j, starts[j]:(starts[j]+sums[j])].shape}")
rewards[j, starts[j]:(starts[j]+sums[j])][-1] += reward_clip[j]
logger.debug(f"after rewards: {rewards.shape}")
return rewards

def train_rlhf(self, inputs):
# train the rlhf model here
### process the old outputs
prompts = inputs['prompts']
log_probs = inputs['logprobs']
ref_log_probs = inputs['ref_logprobs']
reward_score = inputs['rewards']
values = inputs['value']
attention_mask = inputs['attention_mask']
position_ids = inputs['position_ids']
input_ids = inputs['input_ids']

start = prompts.size()[-1] - 1
# process the old outputs
answer_start_indices = inputs['answer_start_indices']
log_probs = inputs['logprobs'] # shape=batch_size * (max_length-1)
ref_log_probs = inputs['ref_logprobs'] # shape=batch_size * (max_length-1)
reward_score = inputs['rewards'] # shape=batch_size
values = inputs['value'] # shape=batch_size * (max_length-1)
attention_mask = inputs['attention_mask'] # shape=batch_size * max_length or shape=batch_size * max_length * max_length
position_ids = inputs['position_ids'] # shape=batch_size * 2 * max_length
input_ids = inputs['input_ids'] # shape=batch_size * max_length
logger.debug(f"[train_rlhf] answer_start_indices: {answer_start_indices}, "
f"log_probs shape: {log_probs.shape}, ref_log_probs shape: {ref_log_probs.shape}, "
f"reward_score shape: {reward_score.shape}, values shape: {values.shape}, "
f"attention_mask shape: {attention_mask.shape if attention_mask is not None else None},"
f"position_ids shape: {position_ids.shape if position_ids is not None else None},"
f"input_ids shape: {input_ids.shape}")

batch_size = input_ids.size()[0]
if attention_mask is not None and len(attention_mask.shape) == 2:
action_mask = attention_mask[:, 1:][:, start:]
# action_mask = attention_mask[:, 1:][:, start:]
action_mask = attention_mask[:, 1:]
else:
answer_ids = input_ids[:, 1:][:, start:]
batch_size = answer_ids.shape[0]
answer_length = answer_ids.shape[-1]
# answer_ids = input_ids[:, 1:][:, start:]
# batch_size = answer_ids.shape[0]
# answer_length = answer_ids.shape[-1]
answer_length = input_ids.shape[-1] - 1
action_mask = torch.ones((batch_size, answer_length), dtype=torch.long, device=input_ids.device)
for i, j in (answer_ids == self.tokenizer.pad_token_id).nonzero():
for i, j in (input_ids[:, 1:] == self.tokenizer.pad_token_id).nonzero():
action_mask[i, j] = 0
for i in range(batch_size):
# set mask of prompt to 0
action_mask[i, :answer_start_indices[i]] = 0
logger.debug(f"[train_rlhf] action_mask shape: {action_mask.shape}")

# compute advantages and returns
print_gpu_utilization("train_rlhf - before compute reward and advantages", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - before compute reward and advantages", self.args.local_rank)
old_values = values
with torch.no_grad():
old_rewards = self.compute_rewards(prompts, log_probs,
old_rewards = self.compute_rewards(answer_start_indices, log_probs,
ref_log_probs, reward_score,
action_mask)
advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, start)
advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, answer_start_indices)
logger.debug(f"[train_rlhf] old_rewards shape: {old_rewards.shape}, advantages shape: {advantages.shape}, returns shape: {returns.shape}")
print_gpu_utilization("train_rlhf - after compute reward and advantages", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - after compute reward and advantages", self.args.local_rank)

### process the new outputs
# update actor and critic
self.train()
batch = {'input_ids': input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_prob = self.actor_model(**batch, use_cache=False).logits # shape=batch_size * max_length * vocab_size
print_gpu_utilization("train_rlhf - after self.actor_model", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - after self.actor_model", self.args.local_rank)
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], input_ids[:, 1:])
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
log_probs[:, start:], advantages,
actor_loss = self.actor_loss_fn(actor_log_prob,
log_probs, advantages,
action_mask)
self.actor_model.backward(actor_loss)
print_gpu_utilization("train_rlhf - after actor backward", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - after actor backward", self.args.local_rank)
self.actor_model.step()
print_gpu_utilization("train_rlhf - after actor step", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - after actor step", self.args.local_rank)
value = self.critic_model.reward(**batch, use_cache=False)[0][:, :-1]
value = self.critic_model.reward(**batch, use_cache=False)[0][:, :-1] # shape=batch_size * (max_length-1)
print_gpu_utilization("train_rlhf - after self.critic_model", self.args.local_rank)
print_gpu_utilization_torch("train_rlhf - after self.critic_model", self.args.local_rank)
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
critic_loss = self.critic_loss_fn(value, old_values,
returns, action_mask)
self.critic_model.backward(critic_loss)
print_gpu_utilization("train_rlhf - after critic backward", self.args.local_rank)
@@ -1539,20 +1564,48 @@ def critic_loss_fn(self, values, old_values, returns, mask):
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss

def get_advantages_and_returns(self, values, rewards, start):
def get_advantages_and_returns(self, values, rewards, starts):
'''
:param values: shape=batch_size * (max_length-1)
:param rewards: shape=batch_size * (max_length-1)
:param starts: list of answer start indices (one per example)
:return: advantages and returns, each of shape=batch_size * (max_length-1)
'''
# Generalized advantage estimation (https://arxiv.org/abs/1707.06347)
logger.debug(f"[get_advantages_and_returns] values: {values.shape}, rewards: {rewards.shape}, start: {start}")
lastgaelam = 0
advantages_reversed = []
logger.debug(f"[get_advantages_and_returns] values: {values.shape}, rewards: {rewards.shape}, starts: {starts}")
batch_size = rewards.size()[0]
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
logger.debug(f"advantages: {advantages.shape}, values[:, start:]: {values[:, start:].shape}")
returns = advantages + values[:, start:]

# lastgaelam = 0
# advantages_reversed = []
# for t in reversed(range(start, length)):
# nextvalues = values[:, t + 1] if t < length - 1 else 0.0
# delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
# lastgaelam = delta + self.gamma * self.lam * lastgaelam
# advantages_reversed.append(lastgaelam)
# advantages = torch.stack(advantages_reversed[::-1], dim=1)
# logger.debug(f"advantages: {advantages.shape}, values[:, start:]: {values[:, start:].shape}")
# returns = advantages + values[:, start:]

advantages = []
returns = []
for i in range(batch_size):
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(starts[i], length)):
nextvalues = values[i, t + 1] if t < length - 1 else 0.0
delta = rewards[i, t] + self.gamma * nextvalues - values[i, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
# set advantage of prompt to 0 (will be ignored when multiplied with action_mask)
advantages_reversed.extend([0]*starts[i])
advantage = torch.tensor(advantages_reversed[::-1], device=values.device, dtype=values.dtype)
advantages.append(advantage)
returns.append(advantage + values[i])
advantages = torch.stack(advantages)
returns = torch.stack(returns)

return advantages.detach(), returns

def _validate_training_mode(self):
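
The core behavioral change in trainer.py is that the single offset start = prompts.shape[1] - 1 is replaced by per-example answer_start_indices, both when adding the clipped sequence-level reward and when computing advantages and returns. The following is a minimal standalone sketch of that logic, not the trainer itself: the names sketch_compute_rewards and sketch_gae are hypothetical, the kl_ctl and clip_reward_value defaults are assumed, and gamma=1.0 / lam=0.95 mirror the defaults shown at the top of this diff.

import torch

def sketch_compute_rewards(starts, log_probs, ref_log_probs, reward_score,
                           action_mask, kl_ctl=0.1, clip_reward_value=5.0):
    # per-token KL penalty against the reference policy
    rewards = -kl_ctl * (log_probs - ref_log_probs)
    sums = action_mask.sum(1)  # number of answer tokens per example
    reward_clip = torch.clamp(reward_score, -clip_reward_value, clip_reward_value)
    for j in range(log_probs.shape[0]):
        end = starts[j] + int(sums[j])
        # add the clipped sequence-level reward to the last answer token
        rewards[j, starts[j]:end][-1] += reward_clip[j]
    return rewards

def sketch_gae(values, rewards, starts, gamma=1.0, lam=0.95):
    # per-example GAE, running the backward recursion from the end down to each answer start
    batch_size, length = rewards.shape
    advantages, returns = [], []
    for i in range(batch_size):
        lastgaelam = 0.0
        adv_reversed = []
        for t in reversed(range(starts[i], length)):
            nextvalues = values[i, t + 1] if t < length - 1 else 0.0
            delta = rewards[i, t] + gamma * nextvalues - values[i, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            adv_reversed.append(lastgaelam)
        adv_reversed.extend([0] * starts[i])  # prompt positions get zero advantage
        adv = torch.tensor(adv_reversed[::-1], device=values.device, dtype=values.dtype)
        advantages.append(adv)
        returns.append(adv + values[i])
    return torch.stack(advantages), torch.stack(returns)

# toy usage: batch of 2, max_length-1 = 6, answers starting at indices 2 and 3
log_probs = torch.zeros(2, 6)
ref_log_probs = torch.zeros(2, 6)
values = torch.zeros(2, 6)
action_mask = torch.tensor([[0, 0, 1, 1, 1, 0], [0, 0, 0, 1, 1, 1]])
rewards = sketch_compute_rewards([2, 3], log_probs, ref_log_probs,
                                 torch.tensor([1.0, -0.5]), action_mask)
advantages, returns = sketch_gae(values, rewards, [2, 3])

The zeros appended before reversing give prompt positions a zero advantage, and the same positions are already zeroed in action_mask, so they drop out of both the actor and critic losses.
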
19 changes: 16 additions & 3 deletions src/train_rlhf.py
@@ -270,16 +270,29 @@ def main():
prompt_iter = iter(prompt_train_dataloader)
pretrain_iter = iter(pretrain_dataloader)
step = 0
while True:
# for step, (batch_prompt, batch_pretrain) in enumerate(zip(prompt_train_dataloader, pretrain_dataloader)):
while True:
# sequence generation: generate one sequence at a time, then aggregate into a batch
answer_start_indices = []
output_sequences = dict()
for _ in range(args.train_batch_size):
try:
batch_prompt = next(prompt_iter)
batch_prompt = {k: v.to(device) for k, v in batch_prompt.items()}
out = trainer.generate_experience(batch_prompt)
exp_dataset = exp_mini_dataset.add(out)
outputs, prompt_length = trainer.generate_sequence(batch_prompt)
answer_start_indices.append(prompt_length-1)
for key, val in outputs.items():
if key not in output_sequences:
output_sequences[key] = []
output_sequences[key].append(val[0])
except StopIteration:
break
if len(output_sequences) > 0:
output_sequences = {key: torch.stack(val) for key, val in output_sequences.items()}
output_experiences = trainer.generate_experience(output_sequences, answer_start_indices, device)
exp_dataset = exp_mini_dataset.add(output_experiences)
else:
exp_dataset = None

try:
batch_pretrain = next(pretrain_iter)

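One detail worth noting in the loop above: answer_start_indices stores prompt_length - 1 rather than prompt_length. The per-token log-probs are obtained by gathering logits[:, :-1, :] against input_ids[:, 1:] (see the gather_log_probs calls in trainer.py), so index t in logprobs/ref_logprobs scores token t + 1 of input_ids, which places the first generated token at index prompt_length - 1; the critic's value is likewise truncated to [:, :-1], so all of these tensors share the max_length - 1 layout. A small sketch of that alignment, assuming a standard gather-based gather_log_probs (its implementation is not shown in this diff):

import torch
import torch.nn.functional as F

def gather_log_probs(logits, labels):
    # assumed implementation: log-probability of each label token under the model
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# toy shapes: batch_size=1, max_length=5, vocab_size=7, prompt_length=3
logits = torch.randn(1, 5, 7)
input_ids = torch.randint(0, 7, (1, 5))
log_probs = gather_log_probs(logits[:, :-1, :], input_ids[:, 1:])  # shape=1 * 4
prompt_length = 3
# log_probs[:, prompt_length - 1] scores input_ids[:, prompt_length],
# i.e. the first answer token, which is why answer_start_indices uses prompt_length - 1
assert log_probs.shape[-1] == input_ids.shape[-1] - 1

With this convention, compute_rewards and get_advantages_and_returns can index log_probs, values and rewards directly with the same starts.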