Multilabel integration into task chains as attribute extraction #902
Conversation
do we need to do something similar to this https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/tasks/base.py#L224 ?
lgtm otherwise
@@ -76,6 +77,47 @@ def logprob_average(
        count += 1
    return logprob_cumulative ** (1.0 / count) if count > 0 else 0


def _logprob_average_per_label(
@rajasbansal for review
Fortunately no. We just return a semicolon-separated list as the label for multilabel attributes, and the product takes care of splitting it for display, just like before. Confidence computation does have to perform some splitting to get a confidence value for each key, however.
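To make that concrete, here is a minimal sketch of that splitting step, assuming a semicolon delimiter; the helper name confidence_per_label and the score_fn callback are illustrative, not the PR's actual API:

# Hypothetical sketch: split a semicolon-separated multilabel answer and
# attach a confidence score to each individual label. The helper name and
# the scoring callback are illustrative only.
from typing import Callable, Dict

def confidence_per_label(
    llm_label: str,
    score_fn: Callable[[str], float],
    delimiter: str = ";",
) -> Dict[str, float]:
    labels = [part.strip() for part in llm_label.split(delimiter) if part.strip()]
    return {label: score_fn(label) for label in labels}

print(confidence_per_label("Red; Blue", lambda label: 0.9))
# {'Red': 0.9, 'Blue': 0.9}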
for conf_label_candiate in conf_label_keys:
    closest_match, closest_match_score = None, 0
    for label in logprob_per_label:
        longest_substring = difflib.SequenceMatcher(
nit: comment on what this function does/what the arguments mean here
added comments!
src/autolabel/confidence.py
longest_substring = difflib.SequenceMatcher(
    None, label, conf_label_candiate
).find_longest_match(0, len(label), 0, len(conf_label_candiate))
if longest_substring.size > closest_match_score:
what if one label is contained within another label, i.e. if the labels are LessRed, MoreRed, and Red? Will that lead to an issue here?
Changed the metric to be the proportion of the longer string that overlaps, instead of the raw number of characters. I think that should take care of this case.
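A minimal sketch of that proportional metric, assuming the overlap is normalized by the longer of the two strings (the function name is made up for illustration):

# Hypothetical sketch: score similarity as the longest common substring
# divided by the length of the longer string, so a short label like "Red"
# no longer ties with longer labels such as "MoreRed" that contain it.
import difflib

def overlap_proportion(label: str, candidate: str) -> float:
    match = difflib.SequenceMatcher(None, label, candidate).find_longest_match(
        0, len(label), 0, len(candidate)
    )
    return match.size / max(len(label), len(candidate))

print(overlap_proportion("Red", "Red"))      # 1.0, exact match wins
print(overlap_proportion("MoreRed", "Red"))  # ~0.43, containing label scores lower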
# Remove all characters before a '\n' character in the logprobs, as
# this is parsed during the generation process.
# In this case, if the input logprobs are
# [{"xx\nc": -1.2}, {"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
# the output logprobs would be
# [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
for i in range(len(logprobs) - 1, -1, -1):
    cur_key = list(logprobs[i].keys())[0]
    if "\n" in cur_key:
        new_key = cur_key.split("\n")[-1].strip()
        if not new_key:
            logprobs = logprobs[i + 1 :]
        else:
            logprobs[i] = {new_key: logprobs[i][cur_key]}
            logprobs = logprobs[i:]
        break
this was needed because of a bug where sometimes the model will output a lot of characters like \n, or some explanation followed by a \n, before giving the label. We handle this in parse_llm_response by removing everything before the last \n. Is the expectation that this will never happen now because of guided generation?
Yep with guided generation this shouldn't happen now
@@ -162,7 +169,7 @@ def logprob_average_per_key(

    # Find the locations of each key in the logprobs as indices
    # into the logprobs list
    locations = []
comment on why we need this?
This is just an implementation choice: it lets us avoid setting logprob_per_key[locations[-1][2]] explicitly at the end.
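For intuition, here is a sketch of that sentinel pattern; the tuple layout and names are assumptions based on the comment above, not the file's actual code:

# Hypothetical sketch: record where each key's tokens start in the logprobs
# list, then append a sentinel end location so the slice for the last key
# falls out of the same loop instead of needing a special case afterwards.
logprobs = [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}]
locations = [("label_1", 0), ("label_2", 2)]  # (key, start_index), illustrative
locations.append((None, len(logprobs)))      # sentinel end marker

logprob_per_key = {}
for (key, start), (_, end) in zip(locations, locations[1:]):
    tokens = logprobs[start:end]
    logprob_per_key[key] = sum(list(tok.values())[0] for tok in tokens) / len(tokens)

print(logprob_per_key)  # {'label_1': -1.25, 'label_2': -1.5}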
f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list" | ||
) | ||
llm_label.pop(attribute["name"], None) | ||
if attr_type == TaskType.CLASSIFICATION: |
also check that every label we get back for multilabel is one of the options, to verify the guidelines were followed
We shouldn't need this either moving forward with guided generation. Will be removing classification as well soon!
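For reference, the check being suggested could look roughly like the sketch below (helper name and inputs are assumptions); per the reply above, guided generation is expected to make it unnecessary:

# Hypothetical sketch of the suggested check: verify that each entry in a
# semicolon-separated multilabel answer is one of the allowed options.
from typing import List

def invalid_multilabel_values(llm_label: str, options: List[str]) -> List[str]:
    labels = [part.strip() for part in llm_label.split(";") if part.strip()]
    return [label for label in labels if label not in options]

print(invalid_multilabel_values("Red;Blue", ["Red", "Blue", "Green"]))    # []
print(invalid_multilabel_values("Red;Purple", ["Red", "Blue", "Green"]))  # ['Purple']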
lgtm if #902 (comment) is not an issue
thanks @DhruvaBansal00 lgtm from my end