Commit
Allow confidence computation per attribute and setting label columns (#572)

Co-authored-by: Rajas Bansal <rajas@refuel.ai>
rajasbansal authored Sep 20, 2023
1 parent acb8755 commit 089b375
Showing 3 changed files with 55 additions and 8 deletions.
39 changes: 39 additions & 0 deletions src/autolabel/confidence.py
@@ -31,6 +31,7 @@ def __init__(
        self.SUPPORTED_CALCULATORS = {
            "logprob_average": self.logprob_average,
            "p_true": self.p_true,
            "logprob_average_per_key": self.logprob_average_per_key,
        }
        self.BASE_API = "https://refuel-llm.refuel.ai/"
        self.REFUEL_API_ENV = "REFUEL_API_KEY"
@@ -62,6 +63,36 @@ def logprob_average(
            count += 1
        return logprob_cumulative / count if count > 0 else 0

    def logprob_average_per_key(
        self,
        logprobs: list,
        keys: list,
        **kwargs,
    ):
        """
        Calculate a confidence score for each key in the generated output.
        Returns a dict mapping each key to its average logprob.
        """

        # Find all '"' and '",' tokens.
        # This is a hacky way to find all the keys in the generation: each
        # key/value pair contributes four quote tokens (key open/close,
        # value open/close).
        indices = []
        for ind, logprob in enumerate(logprobs):
            key = list(logprob.keys())[0]
            if key == '"' or key == '",':
                indices.append(ind)
        if len(indices) != 4 * len(keys):
            logger.error("Unable to find all keys in prompt")
            return {key: 0 for key in keys}

        # Average the logprobs of the value tokens for each key
        logprob_per_key = {}
        for i, key in enumerate(keys):
            logprob_per_key[key] = self.logprob_average(
                logprobs[indices[4 * i + 2] + 1 : indices[4 * i + 3]]
            )
        return logprob_per_key

    def p_true(self, model_generation: LLMAnnotation, prompt: str, **kwargs) -> float:
        p_true_prompt = f"{prompt}{model_generation.raw_response} \n Is the answer to the last example correct? Answer in one word on the same line [Yes/No]: "

@@ -105,9 +136,17 @@ def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
            return model_generation
        logprobs = model_generation.generation_info["logprobs"]["top_logprobs"]

        keys = None
        if self.score_type == "logprob_average_per_key":
            assert isinstance(
                model_generation.label, dict
            ), "logprob_average_per_key requires a dict label from attribute extraction"
            keys = model_generation.label.keys()

        confidence = self.SUPPORTED_CALCULATORS[self.score_type](
            model_generation=model_generation,
            logprobs=logprobs,
            keys=keys,
            **kwargs,
        )
        return confidence
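For context, the sketch below walks through how `logprob_average_per_key` slices the token-level logprobs: each key/value pair in the generated JSON contributes four quote tokens, and the value tokens sit between the third and fourth of them. This is a minimal, self-contained illustration and not part of the commit; the simplified `logprob_average`, the toy generation `{"color": "red"}`, and the logprob values are all hypothetical.

```python
# Minimal sketch of the per-key confidence slicing (illustrative only).

def logprob_average(logprobs):
    # Simplified stand-in for ConfidenceCalculator.logprob_average:
    # mean of the top logprob of each token in the slice.
    if not logprobs:
        return 0
    return sum(list(lp.values())[0] for lp in logprobs) / len(logprobs)

def logprob_average_per_key(logprobs, keys):
    # Indices of the quote tokens; each key/value pair contributes four:
    # key-open, key-close, value-open, value-close.
    indices = [
        i for i, lp in enumerate(logprobs) if list(lp.keys())[0] in ('"', '",')
    ]
    if len(indices) != 4 * len(keys):
        return {key: 0 for key in keys}
    # Value tokens for key i sit strictly between its value-open and
    # value-close quote tokens.
    return {
        key: logprob_average(logprobs[indices[4 * i + 2] + 1 : indices[4 * i + 3]])
        for i, key in enumerate(keys)
    }

# Hypothetical top_logprobs for the generation {"color": "red"}
logprobs = [
    {"{": -0.01},
    {'"': -0.02}, {"color": -0.05}, {'"': -0.02},  # key tokens
    {":": -0.01},
    {'"': -0.02}, {"red": -0.30}, {'"': -0.02},    # value tokens
    {"}": -0.01},
]
print(logprob_average_per_key(logprobs, ["color"]))  # {'color': -0.3}
```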
17 changes: 12 additions & 5 deletions src/autolabel/dataset/dataset.py
@@ -76,11 +76,13 @@ def __init__(
                else df[label_column].tolist()
            )
        else:
            attributes = [attr["name"] for attr in self.config.attributes()]
            gt_labels = {
                name: df[name].tolist() if name in df.keys() else None
                for name in attributes
            }
            gt_labels = {}
            for attr in self.config.attributes():
                name = attr["name"]
                column_name = attr["label_column"] if "label_column" in attr else name
                gt_labels[name] = (
                    df[column_name].tolist() if column_name in df.keys() else None
                )

        self.df = df
        self.inputs = inputs
@@ -148,6 +150,11 @@ def process_labels(
self.df[self.generate_label_name("confidence")] = [
x.confidence_score for x in llm_labels
]
if self.config.task_type() == TaskType.ATTRIBUTE_EXTRACTION:
for attr in self.config.attributes():
self.df[self.generate_label_name("confidence", attr["name"])] = [
x.confidence_score[attr["name"]] for x in llm_labels
]

# Add the LLM explanations to the dataframe if chain of thought is set in config
if self.config.chain_of_thought():
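As an aside, the sketch below shows what the new `label_column` lookup enables: an attribute's ground-truth labels can live in a dataframe column whose name differs from the attribute name, with a fallback to the attribute name when `label_column` is not set. The attribute names, column names, and data are hypothetical; the loop mirrors the logic added above.

```python
# Illustrative only: how label_column resolution picks ground-truth columns.
import pandas as pd

df = pd.DataFrame(
    {
        "text": ["red shirt", "blue jeans"],
        "color_gt": ["red", "blue"],  # stored under a custom column name
        "size": ["M", "L"],           # stored under the attribute name itself
    }
)

attributes = [
    {"name": "color", "label_column": "color_gt"},  # custom label column
    {"name": "size"},                               # falls back to the attribute name
]

gt_labels = {}
for attr in attributes:
    name = attr["name"]
    column_name = attr["label_column"] if "label_column" in attr else name
    gt_labels[name] = df[column_name].tolist() if column_name in df.keys() else None

print(gt_labels)
# {'color': ['red', 'blue'], 'size': ['M', 'L']}
```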
7 changes: 4 additions & 3 deletions src/autolabel/labeler.py
@@ -95,9 +95,10 @@ def __init__(
        self.llm: BaseModel = ModelFactory.from_config(
            self.config, cache=self.generation_cache
        )
        self.confidence = ConfidenceCalculator(
            score_type="logprob_average", llm=self.llm
        )
        score_type = "logprob_average"
        if self.config.task_type() == TaskType.ATTRIBUTE_EXTRACTION:
            score_type = "logprob_average_per_key"
        self.confidence = ConfidenceCalculator(score_type=score_type, llm=self.llm)
        self.example_selector = example_selector

        # Only used if we don't use task management
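Taken together, the labeler now picks `logprob_average_per_key` for attribute-extraction tasks, so each row's `confidence_score` is a dict keyed by attribute, and `dataset.py` fans it out into one confidence column per attribute. A rough illustration of that shape follows; the attribute names, scores, and column naming are hypothetical (the real column names come from `generate_label_name`), and plain dicts stand in for `LLMAnnotation` objects.

```python
# Illustrative only: per-attribute confidence columns built from dict scores.
llm_labels = [
    {"confidence_score": {"color": -0.12, "size": -0.35}},
    {"confidence_score": {"color": -0.08, "size": -0.41}},
]

per_attribute_columns = {
    f"{attr}_confidence": [row["confidence_score"][attr] for row in llm_labels]
    for attr in ("color", "size")
}

print(per_attribute_columns)
# {'color_confidence': [-0.12, -0.08], 'size_confidence': [-0.35, -0.41]}
```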
