# model.py
import math
from typing import List, Optional, Tuple, Union
from pprint import pprint
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers import LlamaModel, LlamaForCausalLM, LlamaPreTrainedModel, LlamaTokenizer
from transformers import BertModel, BertPreTrainedModel
class LlamaRewardModel(LlamaPreTrainedModel):
    """LLaMA backbone with both a language-modeling head and a scalar reward head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Scalar reward head applied to the pooled hidden state.
        self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)
        # Standard LM head, kept so the model can still produce token logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def floating_point_ops(self, inputs):
        # Skip FLOPs accounting in Trainer logging.
        return 0
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pooling_type: str = "average",
        padding_side: str = "right",
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
            Note that this forward pass only returns logits and does not compute a loss itself.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Number of non-pad tokens per sequence.
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1).to(hidden_states.device)
            else:
                sequence_lengths = -1

        if attention_mask is None:
            # Rebuild the mask from the pad token; this branch requires pad_token_id to be set.
            attention_mask = torch.ne(input_ids, self.config.pad_token_id).float()
        attention_mask_ext = attention_mask.unsqueeze(-1)

        if pooling_type in ["last", "eos"]:
            # "eos" selects the final non-pad token, "last" the token just before it.
            offset = 1 if pooling_type == "eos" else 2
            if padding_side == "right":
                pooled_hidden_state = hidden_states[
                    torch.arange(batch_size, device=hidden_states.device), sequence_lengths - offset
                ]
            else:
                # With left padding the last real tokens sit at the end of the sequence.
                pooled_hidden_state = hidden_states[
                    torch.arange(batch_size, device=hidden_states.device), -offset
                ]
        elif pooling_type == "average":
            # Mask out padding, then average over the sequence dimension.
            pooled_hidden_state = (hidden_states * attention_mask_ext).sum(dim=1) / attention_mask_ext.sum(dim=1)
        elif pooling_type == "max":
            pooled_hidden_state = (hidden_states * attention_mask_ext).max(dim=1)[0]
        else:
            raise ValueError("The pooling method {} is not implemented!".format(pooling_type))

        pooled_logits = self.reward_head(pooled_hidden_state)

        return {
            "lm_logits": lm_logits,
            "rm_logits": pooled_logits,
            "hidden_states": transformer_outputs[0],
            "rm_embeddings": pooled_hidden_state,
        }
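

# Minimal usage sketch, not part of the original file: it shows one way to score a
# batch of sequences with LlamaRewardModel, assuming a LLaMA checkpoint is available.
# The checkpoint path and the example prompts below are placeholders, not values
# used by this repository.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/llama-checkpoint"  # hypothetical local checkpoint path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.padding_side = "right"  # match the default padding_side of forward()
    if tokenizer.pad_token is None:
        # LLaMA tokenizers often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = LlamaRewardModel.from_pretrained(checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.eval()

    batch = tokenizer(
        ["Question: 2+2? Answer: 4", "Question: 2+2? Answer: 5"],
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pooling_type="average",
            padding_side="right",
        )
    # outputs["rm_logits"] holds one scalar reward per sequence.
    print(outputs["rm_logits"].squeeze(-1))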