Multilabel integration into task chains as attribute extraction #902

Merged · 5 commits · Sep 24, 2024
148 changes: 90 additions & 58 deletions src/autolabel/confidence.py
@@ -7,8 +7,9 @@
import scipy.stats as stats
import os
import logging
import difflib

from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry
from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry, TaskType
from autolabel.models import BaseModel
from autolabel.cache import BaseCache

@@ -76,6 +77,62 @@ def logprob_average(
count += 1
return logprob_cumulative ** (1.0 / count) if count > 0 else 0

def _logprob_average_per_label(
Contributor: @rajasbansal for review

self,
logprobs: list,
label: str,
delimiter: str = ";",
**kwargs,
) -> Dict[str, float]:
logprob_per_label = {}
curr_logprob_average = self.logprob_average(logprobs)
logprob_per_label = {
curr_label: curr_logprob_average for curr_label in label.split(delimiter)
}
conf_label_keys, prev_key_index, curr_key = {}, 0, ""
for i in range(len(logprobs)):
for curr_chars in list(logprobs[i].keys())[0]:
if delimiter in curr_chars:
curr_key += curr_chars.split(delimiter)[0]
conf_label_keys[curr_key] = self.logprob_average(
logprobs[prev_key_index:i]
)
prev_key_index = i
curr_key = curr_chars.split(delimiter)[-1]
else:
curr_key += curr_chars
if len(curr_key) > 0:
conf_label_keys[curr_key] = self.logprob_average(logprobs[prev_key_index:])

for conf_label_candiate in conf_label_keys:
closest_match, closest_match_score = None, 0
for label in logprob_per_label:
# The SequenceMatcher class is used to compare two sequences. It is especially useful for comparing sequences of characters.
# None - the isjunk argument, a function that flags elements to ignore; passing None means no elements are treated as junk
# label - The first sequence to compare
# conf_label_candiate - The second sequence to compare

# The find_longest_match function returns a named tuple with the following fields:
# a - The start of the matching subsequence in the first sequence
# b - The start of the matching subsequence in the second sequence
# size - The length of the matching subsequence

longest_substring = difflib.SequenceMatcher(
Contributor: nit: comment on what this function does / what the arguments mean here

Author: added comments!

None, label, conf_label_candiate
).find_longest_match(0, len(label), 0, len(conf_label_candiate))
if (
longest_substring.size
/ (1e-6 + max(len(label), len(conf_label_candiate)))
) > closest_match_score:
closest_match = label
closest_match_score = longest_substring.size / (
1e-6 + max(len(label), len(conf_label_candiate))
)
if closest_match is not None:
logprob_per_label[closest_match] = conf_label_keys[conf_label_candiate]

return logprob_per_label
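
For intuition, here is a minimal standalone sketch of the per-label averaging and fuzzy matching performed above. It is not the library code: the token contents, logprob values, and the geometric-mean stand-in for logprob_average are all illustrative, and it assumes the delimiter arrives as its own token.

```python
import difflib
import math

# Hypothetical token-level logprobs for the generation "Abc;Bcd;C".
logprobs = [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
label = "Abc;Bcd;C"
delimiter = ";"

def logprob_average(chunk):
    # Geometric mean of token probabilities, standing in for Confidence.logprob_average.
    if not chunk:
        return 0.0
    probs = [math.exp(next(iter(tok.values()))) for tok in chunk]
    return math.prod(probs) ** (1.0 / len(probs))

# 1. Reconstruct the text between delimiters and average the logprobs of each span.
chunks, start, text = {}, 0, ""
for i, tok in enumerate(logprobs):
    token = next(iter(tok.keys()))
    if delimiter in token:
        chunks[text + token.split(delimiter)[0]] = logprob_average(logprobs[start:i])
        start, text = i, token.split(delimiter)[-1]
    else:
        text += token
if text:
    chunks[text] = logprob_average(logprobs[start:])

# 2. Assign each reconstructed span to the closest expected label by
#    longest-common-substring ratio (what find_longest_match is used for above).
per_label = {}
for candidate, score in chunks.items():
    def overlap(expected):
        match = difflib.SequenceMatcher(None, expected, candidate).find_longest_match(
            0, len(expected), 0, len(candidate)
        )
        return match.size / max(len(expected), len(candidate))
    per_label[max(label.split(delimiter), key=overlap)] = score

print(per_label)  # roughly {'Abc': 0.30, 'Bcd': 0.24, 'C': 0.23}
```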

def logprob_average_per_label(
self,
model_generation: LLMAnnotation,
@@ -86,55 +143,20 @@ def logprob_average_per_key(
"""
Calculates a confidence score per label when there are multiple labels in the response (i.e. multilabel tasks). Returns a dictionary mapping each label to its confidence score.
"""
logprob_per_label = {}
logprobs = model_generation.generation_info["logprobs"]["top_logprobs"]
if logprobs is None or len(logprobs) == 0:
return logprob_per_label

# Remove all characters before a '\n' character in the logprobs, as
# this is parsed during the generation process
# In this case if input logprobs is
# [{"xx\nc": -1.2},{"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
# The output logprobs would be [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
for i in range(len(logprobs) - 1, -1, -1):
cur_key = list(logprobs[i].keys())[0]
if "\n" in cur_key:
new_key = cur_key.split("\n")[-1].strip()
if not new_key:
logprobs = logprobs[i + 1 :]
else:
logprobs[i] = {new_key: logprobs[i][cur_key]}
logprobs = logprobs[i:]
break
Comment on lines -94 to -108
Contributor: this was needed because of a bug where sometimes the model would output a lot of characters like \n, or some explanation followed by a \n, before giving the label. We handle this in parse_llm_response by removing everything before the last \n. Is the expectation that this will never happen now because of guided generation?

Author: Yep, with guided generation this shouldn't happen now.

# Suppose the output for which we compute confidence is "Abc;Bcd;C"
# In this case the logprobs can be a list of dictionaries like
# [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
curr_label = ""
logprob_start_idx = 0
for i in range(len(logprobs)):
for char in list(logprobs[i].keys())[0]:
if char == delimiter:
logprob_end_idx = i if logprob_start_idx < i else i + 1
logprob_per_label[curr_label.strip()] = self.logprob_average(
logprobs[logprob_start_idx:logprob_end_idx]
)
curr_label = ""
logprob_start_idx = i + 1
else:
curr_label += char

# Average the logprobs for the last label (or only label if there is just one label)
if logprob_start_idx < len(logprobs):
logprob_per_label[curr_label.strip()] = self.logprob_average(
logprobs[logprob_start_idx:]
)
return logprob_per_label
return {}
return self._logprob_average_per_label(
logprobs=logprobs,
label=model_generation.label,
delimiter=delimiter,
)

def logprob_average_per_key(
self,
model_generation: LLMAnnotation,
logprobs: Union[list, dict],
keys: list,
keys: Dict[str, str],
**kwargs,
):
"""
@@ -162,7 +184,7 @@ def logprob_average_per_key(

# Find the locations of each key in the logprobs as indices
# into the logprobs list
locations = []
Contributor: comment on why we need this?

Author: This is just an implementation choice - it allows us to avoid setting logprob_per_key[locations[-1][2]] at the end explicitly.

locations = [(len(logprobs), len(logprobs), "")]
for key in keys:
key_to_find = f'"{key}":'
loc = full_string.find(key_to_find)
@@ -180,21 +202,27 @@

if len(logprob_per_key) != 0:
logger.warning("Some keys not found in logprobs")

for i in range(len(locations) - 1):
# Average the logprobs of all tokens from the end of this key token to the
# start of the next key token; for a key "A" this averages the tokens
# responsible for generating its value "B".
logprob_per_key[locations[i][2]] = self.logprob_average(
logprobs[locations[i][1] + 1 : locations[i + 1][0]]
)
if len(locations) > 0 and len(logprobs) > locations[-1][1] + 1:
# Average the logprobs from the end of the last token to the end of the logprobs
logprob_per_key[locations[-1][2]] = self.logprob_average(
logprobs[locations[-1][1] + 1 :]
)

curr_key = locations[i][2]
if (
curr_key in keys
and keys[curr_key] == TaskType.MULTILABEL_CLASSIFICATION
):
logger.error(
f"Calculating logprob average per label for key {curr_key} with logprobs {logprobs[locations[i][1] + 1 : locations[i + 1][0]]}"
)
logprob_per_key[curr_key] = self._logprob_average_per_label(
logprobs[locations[i][1] + 1 : locations[i + 1][0]],
label=model_generation.label[curr_key],
)
else:
logprob_per_key[curr_key] = self.logprob_average(
logprobs[locations[i][1] + 1 : locations[i + 1][0]]
)
return logprob_per_key
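
To illustrate the sentinel discussed in the exchange above: appending a final location that points just past the end of the logprobs list lets a single loop apply the same slice logic to every key, including the last one. A toy sketch follows; the token contents, logprob values, hard-coded locations, and the plain arithmetic mean are illustrative, not the library's logprob_average.

```python
# Hypothetical token logprobs for a generation like '{"A": "yes", "B": "no"}'.
logprobs = [
    {'{"A":': -0.1}, {' "y': -0.3}, {'es",': -0.2},
    {' "B":': -0.1}, {' "n': -0.4}, {'o"}': -0.2},
]

def avg(chunk):
    # Arithmetic mean of logprobs, standing in for Confidence.logprob_average.
    return sum(next(iter(tok.values())) for tok in chunk) / len(chunk) if chunk else 0.0

# (start_token_idx, end_token_idx, key) for each key found in the generation,
# plus a sentinel pointing just past the end of the logprobs list.
locations = [(0, 0, "A"), (3, 3, "B"), (len(logprobs), len(logprobs), "")]

logprob_per_key = {}
for i in range(len(locations) - 1):
    # Tokens after this key token, up to (but not including) the next key token.
    chunk = logprobs[locations[i][1] + 1 : locations[i + 1][0]]
    logprob_per_key[locations[i][2]] = avg(chunk)

print(logprob_per_key)  # roughly {'A': -0.25, 'B': -0.3}; no trailing special case for "B"
```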

async def p_true(self, model_generation: LLMAnnotation, **kwargs) -> float:
@@ -236,7 +264,12 @@ def return_empty_logprob(
model_generation.confidence_score = 0
return model_generation.confidence_score

async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
async def calculate(
self,
model_generation: LLMAnnotation,
keys: Optional[Dict] = None,
**kwargs,
) -> float:
if self.score_type not in self.SUPPORTED_CALCULATORS:
raise NotImplementedError()

@@ -282,12 +315,11 @@ async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
return model_generation
logprobs = model_generation.generation_info["logprobs"]["top_logprobs"]

keys = None
if self.score_type == "logprob_average_per_key":
assert isinstance(
model_generation.label, dict
), "logprob_average_per_key requires a dict label from attribute extraction"
keys = model_generation.label.keys()
assert keys is not None, "Keys must be provided for logprob_average_per_key"

confidence = self.SUPPORTED_CALCULATORS[self.score_type](
model_generation=model_generation,
13 changes: 12 additions & 1 deletion src/autolabel/labeler.py
@@ -323,9 +323,20 @@ async def arun(

if self.config.confidence():
try:
keys = (
{
attribute_dict.get(
"name", ""
): attribute_dict.get("task_type", "")
for attribute_dict in self.config.attributes()
}
if self.config.task_type()
== TaskType.ATTRIBUTE_EXTRACTION
else None
)
annotation.confidence_score = (
await self.confidence.calculate(
model_generation=annotation
model_generation=annotation, keys=keys
)
)
if (
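
As a sketch of what the keys mapping built above ends up looking like: the attribute names, options, and task-type strings below are invented, and the task-type values are assumed string values of the TaskType enum.

```python
# Hypothetical output of self.config.attributes() for an attribute-extraction config.
attributes = [
    {"name": "sentiment", "task_type": "classification", "options": ["positive", "negative"]},
    {"name": "topics", "task_type": "multilabel_classification", "options": ["sports", "politics", "tech"]},
]

# Same comprehension as in arun(): attribute name -> task type.
keys = {
    attribute_dict.get("name", ""): attribute_dict.get("task_type", "")
    for attribute_dict in attributes
}
print(keys)  # {'sentiment': 'classification', 'topics': 'multilabel_classification'}

# This mapping is passed to confidence.calculate(..., keys=keys) so that
# logprob_average_per_key can switch to per-label averaging for multilabel attributes.
```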
2 changes: 1 addition & 1 deletion src/autolabel/schema.py
@@ -117,7 +117,7 @@ class LLMAnnotation(BaseModel):
successfully_labeled: bool
label: Any
curr_sample: Optional[bytes] = ""
confidence_score: Optional[float] = None
confidence_score: Optional[Union[float, Dict[str, float]]] = None
generation_info: Optional[Dict[str, Any]] = None
raw_response: Optional[str] = ""
explanation: Optional[str] = ""
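
With the widened type, confidence_score can hold either a single score or a per-attribute mapping; a small illustration with invented values:

```python
# Single score, e.g. a plain classification task scored with logprob_average.
confidence_score = 0.87

# Per-attribute scores, e.g. attribute extraction scored with logprob_average_per_key.
confidence_score = {"sentiment": 0.91, "topics": 0.64}
```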
20 changes: 14 additions & 6 deletions src/autolabel/tasks/attribute_extraction.py
@@ -71,6 +71,11 @@ def _construct_attribute_json(self) -> str:

attribute_name = attribute_dict["name"]
attribute_desc = attribute_dict["description"]
if TaskType.MULTILABEL_CLASSIFICATION == attribute_dict.get(
"task_type", ""
):
attribute_desc += " Output should be a list of labels from the options provided below, separated by semicolons."

if "options" in attribute_dict:
attribute_options = attribute_dict["options"]
attribute_desc += f"\nOptions:\n{','.join(attribute_options)}"
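
As an illustration of the description this branch builds for a multilabel attribute: the attribute dict below is hypothetical, the task-type string is an assumed value of TaskType.MULTILABEL_CLASSIFICATION, and the appended wording is taken from the diff.

```python
attribute_dict = {
    "name": "topics",
    "description": "Topics discussed in the text.",
    "task_type": "multilabel_classification",  # assumed enum value
    "options": ["sports", "politics", "tech"],
}

attribute_desc = attribute_dict["description"]
if attribute_dict.get("task_type", "") == "multilabel_classification":
    attribute_desc += (
        " Output should be a list of labels from the options provided below,"
        " separated by semicolons."
    )
if "options" in attribute_dict:
    attribute_desc += f"\nOptions:\n{','.join(attribute_dict['options'])}"

print(attribute_desc)
# Topics discussed in the text. Output should be a list of labels from the options provided below, separated by semicolons.
# Options:
# sports,politics,tech
```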
@@ -268,14 +273,17 @@ def parse_llm_response(

if successfully_labeled:
for attribute in self.config.attributes():
attr_options = attribute.get("options")
attr_options, attr_type = attribute.get("options"), attribute.get(
"task_type"
)
if attr_options is not None and len(attr_options) > 0:
attr_label = str(llm_label.get(attribute["name"]))
if attr_label is not None and attr_label not in attr_options:
logger.warning(
f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list"
)
llm_label.pop(attribute["name"], None)
if attr_type == TaskType.CLASSIFICATION:
Contributor: also check that every label that we get for multilabel is one of the options, if the guidelines were followed

Author: We shouldn't need this either moving forward with guided generation. Will be removing classification as well soon!

if attr_label is not None and attr_label not in attr_options:
logger.warning(
f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list"
)
llm_label.pop(attribute["name"], None)
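
The check suggested in the review above (validating each semicolon-separated label of a multilabel attribute against its options) is not part of this diff, since the author expects guided generation to make it unnecessary; a hedged sketch of what it could look like:

```python
# Hypothetical validation for a multilabel attribute; not implemented in this PR.
attr_label = "sports;unknown_topic"
attr_options = ["sports", "politics", "tech"]

invalid = [lbl.strip() for lbl in attr_label.split(";") if lbl.strip() and lbl.strip() not in attr_options]
if invalid:
    # Mirroring the classification branch: warn and drop the attribute.
    print(f"Labels {invalid} are not in the options list")
```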

return LLMAnnotation(
curr_sample=pickle.dumps(curr_sample),