Multilabel integration into task chains as attribute extraction #902
Changes from all commits: b69ab27, ff60a52, 0167c70, df7f0bb, 426cd88
@@ -7,8 +7,9 @@
import scipy.stats as stats
import os
import logging
import difflib

from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry
from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry, TaskType
from autolabel.models import BaseModel
from autolabel.cache import BaseCache

@@ -76,6 +77,62 @@ def logprob_average(
        count += 1
    return logprob_cumulative ** (1.0 / count) if count > 0 else 0

def _logprob_average_per_label(
    self,
    logprobs: list,
    label: str,
    delimiter: str = ";",
    **kwargs,
) -> Dict[str, float]:
    logprob_per_label = {}
    curr_logprob_average = self.logprob_average(logprobs)
    logprob_per_label = {
        curr_label: curr_logprob_average for curr_label in label.split(delimiter)
    }
    conf_label_keys, prev_key_index, curr_key = {}, 0, ""
    for i in range(len(logprobs)):
        for curr_chars in list(logprobs[i].keys())[0]:
            if delimiter in curr_chars:
                curr_key += curr_chars.split(delimiter)[0]
                conf_label_keys[curr_key] = self.logprob_average(
                    logprobs[prev_key_index:i]
                )
                prev_key_index = i
                curr_key = curr_chars.split(delimiter)[-1]
            else:
                curr_key += curr_chars
    if len(curr_key) > 0:
        conf_label_keys[curr_key] = self.logprob_average(logprobs[prev_key_index:])

    for conf_label_candiate in conf_label_keys:
        closest_match, closest_match_score = None, 0
        for label in logprob_per_label:
            # The SequenceMatcher class is used to compare two sequences. It is especially useful for comparing sequences of characters.
            # None - This is a function that is used to compare the two sequences. If it is None, the default function is used.
            # label - The first sequence to compare
            # conf_label_candiate - The second sequence to compare

            # The find_longest_match function returns a named tuple with the following fields:
            # a - The start of the matching subsequence in the first sequence
            # b - The start of the matching subsequence in the second sequence
            # size - The length of the matching subsequence
            longest_substring = difflib.SequenceMatcher(
                None, label, conf_label_candiate
            ).find_longest_match(0, len(label), 0, len(conf_label_candiate))

[Review comment] nit: comment on what this function does/what the arguments mean here
[Reply] added comments!

            if (
                longest_substring.size
                / (1e-6 + max(len(label), len(conf_label_candiate)))
            ) > closest_match_score:
                closest_match = label
                closest_match_score = longest_substring.size / (
                    1e-6 + max(len(label), len(conf_label_candiate))
                )
        if closest_match is not None:
            logprob_per_label[closest_match] = conf_label_keys[conf_label_candiate]

    return logprob_per_label

def logprob_average_per_label(
    self,
    model_generation: LLMAnnotation,
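To make the flow above concrete, here is a minimal, self-contained sketch of the same idea as standalone functions. The helper names `average_logprob` and `per_label_confidence` are hypothetical, and `SequenceMatcher.ratio()` is used as a simpler stand-in for the normalized `find_longest_match` score in the diff: token logprobs are split into chunks at the delimiter, each chunk is averaged, and the chunk's score is assigned to the most similar label parsed from the final output.

```python
import difflib
import math
from typing import Dict, List


def average_logprob(logprobs: List[dict]) -> float:
    # Simplified stand-in for ConfidenceCalculator.logprob_average:
    # geometric-mean-style average of the token probabilities.
    if not logprobs:
        return 0.0
    total = sum(list(lp.values())[0] for lp in logprobs)
    return math.exp(total / len(logprobs))


def per_label_confidence(
    logprobs: List[dict], label: str, delimiter: str = ";"
) -> Dict[str, float]:
    # Every parsed label starts at the overall average as a fallback.
    overall = average_logprob(logprobs)
    result = {lbl: overall for lbl in label.split(delimiter)}

    # Split the token stream into chunks at the delimiter character.
    chunks, start, current = {}, 0, ""
    for i, token_info in enumerate(logprobs):
        token = list(token_info.keys())[0]
        if delimiter in token:
            current += token.split(delimiter)[0]
            chunks[current] = average_logprob(logprobs[start:i])
            start, current = i, token.split(delimiter)[-1]
        else:
            current += token
    if current:
        chunks[current] = average_logprob(logprobs[start:])

    # Assign each chunk's averaged score to the most similar parsed label.
    for chunk, score in chunks.items():
        best, best_ratio = None, 0.0
        for lbl in result:
            ratio = difflib.SequenceMatcher(None, lbl, chunk).ratio()
            if ratio > best_ratio:
                best, best_ratio = lbl, ratio
        if best is not None:
            result[best] = score
    return result


# Output "Abc;Bcd" tokenized as ["Abc", ";", "B", "cd"]
tokens = [{"Abc": -0.1}, {";": -0.2}, {"B": -0.5}, {"cd": -0.4}]
print(per_label_confidence(tokens, "Abc;Bcd"))
```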
@@ -86,55 +143,20 @@ def logprob_average_per_label(
    This function calculates the confidence score per label when there are multiple labels in the response (i.e. multilabel tasks). This will return
    a confidence score per label.
    """
    logprob_per_label = {}
    logprobs = model_generation.generation_info["logprobs"]["top_logprobs"]
    if logprobs is None or len(logprobs) == 0:
        return logprob_per_label

    # Remove all characters before a '\n' character in the logprobs, as
    # this is parsed during the generation process
    # In this case if input logprobs is
    # [{"xx\nc": -1.2},{"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
    # The output logprobs would be [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
    for i in range(len(logprobs) - 1, -1, -1):
        cur_key = list(logprobs[i].keys())[0]
        if "\n" in cur_key:
            new_key = cur_key.split("\n")[-1].strip()
            if not new_key:
                logprobs = logprobs[i + 1 :]
            else:
                logprobs[i] = {new_key: logprobs[i][cur_key]}
                logprobs = logprobs[i:]
            break

[Review comment on lines -94 to -108] This was needed because of a bug where sometimes the model will output a lot of characters like \n, or some explanation followed by a \n, to give the label. We handle this in parse_llm_response by removing everything before the last \n. Is the expectation that this will never happen now because of guided generation?
[Reply] Yep, with guided generation this shouldn't happen now.

    # Suppose the output for which we compute confidence is "Abc;Bcd;C"
    # In this case the logprobs can be a list of dictionaries like
    # [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
    curr_label = ""
    logprob_start_idx = 0
    for i in range(len(logprobs)):
        for char in list(logprobs[i].keys())[0]:
            if char == delimiter:
                logprob_end_idx = i if logprob_start_idx < i else i + 1
                logprob_per_label[curr_label.strip()] = self.logprob_average(
                    logprobs[logprob_start_idx:logprob_end_idx]
                )
                curr_label = ""
                logprob_start_idx = i + 1
            else:
                curr_label += char

    # Average the logprobs for the last label (or only label if there is just one label)
    if logprob_start_idx < len(logprobs):
        logprob_per_label[curr_label.strip()] = self.logprob_average(
            logprobs[logprob_start_idx:]
        )
    return logprob_per_label
        return {}
    return self._logprob_average_per_label(
        logprobs=logprobs,
        label=model_generation.label,
        delimiter=delimiter,
    )

def logprob_average_per_key(
    self,
    model_generation: LLMAnnotation,
    logprobs: Union[list, dict],
    keys: list,
    keys: Dict[str, str],
    **kwargs,
):
    """
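For reference, the newline cleanup that the comment thread above discusses (removed here because guided generation is expected to make it unnecessary) can be sketched as a small standalone helper. The function name is hypothetical and this is a simplified, non-mutating version of the removed loop:

```python
from typing import List


def strip_before_last_newline(logprobs: List[dict]) -> List[dict]:
    # Walk the tokens backwards; the last token containing '\n' marks where
    # the actual label starts, so everything before that point is dropped.
    for i in range(len(logprobs) - 1, -1, -1):
        token = list(logprobs[i].keys())[0]
        if "\n" in token:
            tail = token.split("\n")[-1].strip()
            if not tail:
                return logprobs[i + 1:]
            # Keep only the part of this token that follows the newline.
            return [{tail: logprobs[i][token]}] + logprobs[i + 1:]
    return logprobs


tokens = [{"xx\nc": -1.2}, {"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}]
print(strip_before_last_newline(tokens))
# [{'Abc': -1.2}, {';': -1.3}, {'B': -1.4}]
```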
@@ -162,7 +184,7 @@ def logprob_average_per_key(

    # Find the locations of each key in the logprobs as indices
    # into the logprobs list
    locations = []
    locations = [(len(logprobs), len(logprobs), "")]

[Review comment] comment on why we need this?
[Reply] This is just an implementation choice - it allows us to avoid setting the last key's confidence separately after the loop.

    for key in keys:
        key_to_find = f'"{key}":'
        loc = full_string.find(key_to_find)
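A rough, self-contained sketch of what the location-finding step could look like (all names here are illustrative, not the exact code in this file): the token strings are concatenated, each `"key":` pattern is located by character offset, that offset is mapped back to a token index, and the sentinel entry pointing past the end of the token list lets the later slicing loop treat the final key like any other.

```python
from typing import List, Tuple


def find_key_locations(
    logprobs: List[dict], keys: List[str]
) -> List[Tuple[int, int, str]]:
    tokens = [list(lp.keys())[0] for lp in logprobs]
    full_string = "".join(tokens)

    # Character offset at which each token starts, used to map a
    # character position back to a token index.
    offsets, pos = [], 0
    for tok in tokens:
        offsets.append(pos)
        pos += len(tok)

    def char_to_token(char_idx: int) -> int:
        for i, start in enumerate(offsets):
            if start + len(tokens[i]) > char_idx:
                return i
        return len(tokens) - 1

    # Sentinel past the end of the token list so the caller never has to
    # special-case the final key when slicing between locations.
    locations = [(len(logprobs), len(logprobs), "")]
    for key in keys:
        key_to_find = f'"{key}":'
        loc = full_string.find(key_to_find)
        if loc == -1:
            continue
        start_tok = char_to_token(loc)
        end_tok = char_to_token(loc + len(key_to_find) - 1)
        locations.append((start_tok, end_tok, key))
    return sorted(locations)


tokens = [{'{"': -0.1}, {'a": "x"': -0.2}, {', "b":': -0.3}, {' "y"}': -0.2}]
print(find_key_locations(tokens, ["a", "b"]))
# [(0, 1, 'a'), (2, 2, 'b'), (4, 4, '')]
```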
@@ -180,21 +202,27 @@ def logprob_average_per_key(

    if len(logprob_per_key) != 0:
        logger.warning("Some keys not found in logprobs")

    for i in range(len(locations) - 1):
        # Average the logprobs from the end of this to the start of the next token
        # This means that we average the logprobs of all tokens from the end of the key token
        # to the start of the next key token thus for the key "A" this would average the tokens
        # responsible for generating "B",
        logprob_per_key[locations[i][2]] = self.logprob_average(
            logprobs[locations[i][1] + 1 : locations[i + 1][0]]
        )
    if len(locations) > 0 and len(logprobs) > locations[-1][1] + 1:
        # Average the logprobs from the end of the last token to the end of the logprobs
        logprob_per_key[locations[-1][2]] = self.logprob_average(
            logprobs[locations[-1][1] + 1 :]
        )

        curr_key = locations[i][2]
        if (
            curr_key in keys
            and keys[curr_key] == TaskType.MULTILABEL_CLASSIFICATION
        ):
            logger.error(
                f"Calculating logprob average per label for key {curr_key} with logprobs {logprobs[locations[i][1] + 1 : locations[i + 1][0]]}"
            )
            logprob_per_key[curr_key] = self._logprob_average_per_label(
                logprobs[locations[i][1] + 1 : locations[i + 1][0]],
                label=model_generation.label[curr_key],
            )
        else:
            logprob_per_key[curr_key] = self.logprob_average(
                logprobs[locations[i][1] + 1 : locations[i + 1][0]]
            )
    return logprob_per_key

async def p_true(self, model_generation: LLMAnnotation, **kwargs) -> float:
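Building on the two sketches above (this one reuses the hypothetical `per_label_confidence` and `average_logprob` helpers, with a string constant standing in for `TaskType.MULTILABEL_CLASSIFICATION`), the per-key routing added in this hunk could be approximated like this; thanks to the sentinel, the last real key's slice simply runs to the end of the token list:

```python
from typing import Dict, List, Tuple

MULTILABEL = "multilabel_classification"  # stand-in for TaskType.MULTILABEL_CLASSIFICATION


def confidence_per_key(
    logprobs: List[dict],
    locations: List[Tuple[int, int, str]],  # sorted key locations, ending with the sentinel
    keys: Dict[str, str],                   # attribute name -> task type
    label: Dict[str, str],                  # parsed label per attribute
) -> Dict[str, object]:
    result = {}
    for i in range(len(locations) - 1):
        key = locations[i][2]
        # The tokens between the end of this key and the start of the next
        # key are the ones that generated this key's value.
        value_tokens = logprobs[locations[i][1] + 1 : locations[i + 1][0]]
        if keys.get(key) == MULTILABEL:
            # Multilabel attributes get a per-label confidence dict.
            result[key] = per_label_confidence(value_tokens, label[key])
        else:
            result[key] = average_logprob(value_tokens)
    return result
```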
@@ -236,7 +264,12 @@ def return_empty_logprob(
    model_generation.confidence_score = 0
    return model_generation.confidence_score

async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
async def calculate(
    self,
    model_generation: LLMAnnotation,
    keys: Optional[Dict] = None,
    **kwargs,
) -> float:
    if self.score_type not in self.SUPPORTED_CALCULATORS:
        raise NotImplementedError()
@@ -282,12 +315,11 @@ async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
        return model_generation
    logprobs = model_generation.generation_info["logprobs"]["top_logprobs"]

    keys = None
    if self.score_type == "logprob_average_per_key":
        assert isinstance(
            model_generation.label, dict
        ), "logprob_average_per_key requires a dict label from attribute extraction"
        keys = model_generation.label.keys()
        assert keys is not None, "Keys must be provided for logprob_average_per_key"

    confidence = self.SUPPORTED_CALCULATORS[self.score_type](
        model_generation=model_generation,
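With this change the caller supplies the key-to-task-type mapping instead of `calculate()` deriving the keys from the label dict. A hypothetical calling pattern (the `config.attributes()` shape is taken from the task code below; `confidence_calculator` and the default task type string are illustrative):

```python
async def score_with_confidence(confidence_calculator, config, annotation):
    # Build the keys dict from the configured attributes so that multilabel
    # attributes are routed to per-label confidence scoring, then pass it
    # explicitly to calculate() as the new signature requires.
    keys = {
        attr["name"]: attr.get("task_type", "classification")
        for attr in config.attributes()
    }
    return await confidence_calculator.calculate(
        model_generation=annotation,
        keys=keys,
    )
```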
@@ -71,6 +71,11 @@ def _construct_attribute_json(self) -> str:

        attribute_name = attribute_dict["name"]
        attribute_desc = attribute_dict["description"]
        if TaskType.MULTILABEL_CLASSIFICATION == attribute_dict.get(
            "task_type", ""
        ):
            attribute_desc += " Output should be a list of labels from the options provided below, separated by semicolons."

        if "options" in attribute_dict:
            attribute_options = attribute_dict["options"]
            attribute_desc += f"\nOptions:\n{','.join(attribute_options)}"
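A small runnable sketch of the description-building behaviour this hunk adds (simplified to a free function; the real method assembles a larger attribute JSON, and the string constant stands in for the TaskType enum value):

```python
MULTILABEL_CLASSIFICATION = "multilabel_classification"  # stand-in for TaskType.MULTILABEL_CLASSIFICATION


def build_attribute_desc(attribute_dict: dict) -> str:
    attribute_desc = attribute_dict["description"]
    # Multilabel attributes get an extra instruction about the output format.
    if attribute_dict.get("task_type", "") == MULTILABEL_CLASSIFICATION:
        attribute_desc += (
            " Output should be a list of labels from the options provided"
            " below, separated by semicolons."
        )
    if "options" in attribute_dict:
        attribute_desc += "\nOptions:\n" + ",".join(attribute_dict["options"])
    return attribute_desc


print(build_attribute_desc({
    "name": "genres",
    "description": "Genres that apply to the movie.",
    "task_type": "multilabel_classification",
    "options": ["action", "comedy", "drama"],
}))
```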
@@ -268,14 +273,17 @@ def parse_llm_response(

    if successfully_labeled:
        for attribute in self.config.attributes():
            attr_options = attribute.get("options")
            attr_options, attr_type = attribute.get("options"), attribute.get(
                "task_type"
            )
            if attr_options is not None and len(attr_options) > 0:
                attr_label = str(llm_label.get(attribute["name"]))
                if attr_label is not None and attr_label not in attr_options:
                    logger.warning(
                        f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list"
                    )
                    llm_label.pop(attribute["name"], None)
                if attr_type == TaskType.CLASSIFICATION:
                    if attr_label is not None and attr_label not in attr_options:
                        logger.warning(
                            f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list"
                        )
                        llm_label.pop(attribute["name"], None)

[Review comment] also check that every label that we get for multilabel is one of the options if the guidelines were followed
[Reply] We shouldn't need this either moving forward with guided generation. Will be removing classification as well soon!

    return LLMAnnotation(
        curr_sample=pickle.dumps(curr_sample),

[Review comment] @rajasbansal for review