From b69ab270c7c9d852463cd72b638ae4d840edf35b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 17 Sep 2024 17:54:17 -0700 Subject: [PATCH 1/5] Multilabel confidence computation in attribute extraction --- src/autolabel/confidence.py | 136 +++++++++++--------- src/autolabel/labeler.py | 17 ++- src/autolabel/schema.py | 2 +- src/autolabel/tasks/attribute_extraction.py | 20 ++- 4 files changed, 106 insertions(+), 69 deletions(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index 74953c15..6fbf62a0 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -7,8 +7,9 @@ import scipy.stats as stats import os import logging +import difflib -from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry +from autolabel.schema import LLMAnnotation, ConfidenceCacheEntry, TaskType from autolabel.models import BaseModel from autolabel.cache import BaseCache @@ -76,6 +77,47 @@ def logprob_average( count += 1 return logprob_cumulative ** (1.0 / count) if count > 0 else 0 + def _logprob_average_per_label( + self, + logprobs: list, + label: str, + delimiter: str = ";", + **kwargs, + ) -> Dict[str, float]: + logprob_per_label = {} + curr_logprob_average = self.logprob_average(logprobs) + logprob_per_label = { + curr_label: curr_logprob_average for curr_label in label.split(delimiter) + } + conf_label_keys, prev_key_index, curr_key = {}, 0, "" + for i in range(len(logprobs)): + for curr_chars in list(logprobs[i].keys())[0]: + if delimiter in curr_chars: + curr_key += curr_chars.split(delimiter)[0] + conf_label_keys[curr_key] = self.logprob_average( + logprobs[prev_key_index:i] + ) + prev_key_index = i + curr_key = curr_chars.split(delimiter)[-1] + else: + curr_key += curr_chars + if len(curr_key) > 0: + conf_label_keys[curr_key] = self.logprob_average(logprobs[prev_key_index:]) + + for conf_label_candiate in conf_label_keys: + closest_match, closest_match_score = None, 0 + for label in logprob_per_label: + longest_substring = difflib.SequenceMatcher( + None, label, conf_label_candiate + ).find_longest_match(0, len(label), 0, len(conf_label_candiate)) + if longest_substring.size > closest_match_score: + closest_match = label + closest_match_score = longest_substring.size + if closest_match is not None: + logprob_per_label[closest_match] = conf_label_keys[conf_label_candiate] + + return logprob_per_label + def logprob_average_per_label( self, model_generation: LLMAnnotation, @@ -86,55 +128,18 @@ def logprob_average_per_label( This function calculates the confidence score per label when there are multiple labels in the response (i.e. multilabel tasks). This will return a confidence score per label. 
""" - logprob_per_label = {} logprobs = model_generation.generation_info["logprobs"]["top_logprobs"] if logprobs is None or len(logprobs) == 0: - return logprob_per_label - - # Remove all characters before a '\n' character in the logprobs, as - # this is parsed during the generation process - # In this case if input logprobs is - # [{"xx\nc": -1.2},{"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}] - # The output logprobs would be [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}] - for i in range(len(logprobs) - 1, -1, -1): - cur_key = list(logprobs[i].keys())[0] - if "\n" in cur_key: - new_key = cur_key.split("\n")[-1].strip() - if not new_key: - logprobs = logprobs[i + 1 :] - else: - logprobs[i] = {new_key: logprobs[i][cur_key]} - logprobs = logprobs[i:] - break - - # Suppose the output for which we compute confidence is "Abc;Bcd;C" - # In this case the logprobs can be a list of dictionaries like - # [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}] - curr_label = "" - logprob_start_idx = 0 - for i in range(len(logprobs)): - for char in list(logprobs[i].keys())[0]: - if char == delimiter: - logprob_end_idx = i if logprob_start_idx < i else i + 1 - logprob_per_label[curr_label.strip()] = self.logprob_average( - logprobs[logprob_start_idx:logprob_end_idx] - ) - curr_label = "" - logprob_start_idx = i + 1 - else: - curr_label += char - - # Average the logprobs for the last label (or only label if there is just one label) - if logprob_start_idx < len(logprobs): - logprob_per_label[curr_label.strip()] = self.logprob_average( - logprobs[logprob_start_idx:] - ) - return logprob_per_label + return {} + return self._logprob_average_per_label( + logprobs, delimiter, label=model_generation.label + ) def logprob_average_per_key( self, + model_generation: LLMAnnotation, logprobs: Union[list, dict], - keys: list, + keys: Dict[str, str], **kwargs, ): """ @@ -162,7 +167,7 @@ def logprob_average_per_key( # Find the locations of each key in the logprobs as indices # into the logprobs list - locations = [] + locations = [(len(logprobs), len(logprobs), "")] for key in keys: key_to_find = f'"{key}":' loc = full_string.find(key_to_find) @@ -174,27 +179,38 @@ def logprob_average_per_key( end_token = mapping[loc + len(key_to_find) - 1] locations.append((start_token, end_token, key)) locations.sort() + + logger.error(f"Keys: {keys}") + logger.error(f"Locations: {locations}") + logger.error(f"Logprobs: {logprobs}") # Here, the locations consist of the start and end *token* indices for each key # i.e for the keys A and B, we find the start and end tokens where they are found in the logprobs list # and store them in locations. 
For eg - locations can be [(1, 3, "A"), (9, 12, "C")] if len(logprob_per_key) != 0: logger.warning("Some keys not found in logprobs") - for i in range(len(locations) - 1): # Average the logprobs from the end of this to the start of the next token # This means that we average the logprobs of all tokens from the end of the key token # to the start of the next key token thus for the key "A" this would average the tokens # responsible for generating "B", - logprob_per_key[locations[i][2]] = self.logprob_average( - logprobs[locations[i][1] + 1 : locations[i + 1][0]] - ) - if len(locations) > 0 and len(logprobs) > locations[-1][1] + 1: - # Average the logprobs from the end of the last token to the end of the logprobs - logprob_per_key[locations[-1][2]] = self.logprob_average( - logprobs[locations[-1][1] + 1 :] - ) - + curr_key = locations[i][2] + if ( + curr_key in keys + and keys[curr_key] == TaskType.MULTILABEL_CLASSIFICATION + ): + logger.error( + f"Calculating logprob average per label for key {curr_key} with logprobs {logprobs[locations[i][1] + 1 : locations[i + 1][0]]}" + ) + logprob_per_key[curr_key] = self._logprob_average_per_label( + logprobs[locations[i][1] + 1 : locations[i + 1][0]], + label=model_generation.label[curr_key], + ) + else: + logprob_per_key[curr_key] = self.logprob_average( + logprobs[locations[i][1] + 1 : locations[i + 1][0]] + ) + logger.error(f"Logprob per key: {logprob_per_key}") return logprob_per_key async def p_true(self, model_generation: LLMAnnotation, **kwargs) -> float: @@ -236,7 +252,12 @@ def return_empty_logprob( model_generation.confidence_score = 0 return model_generation.confidence_score - async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float: + async def calculate( + self, + model_generation: LLMAnnotation, + keys: Optional[Dict] = None, + **kwargs, + ) -> float: if self.score_type not in self.SUPPORTED_CALCULATORS: raise NotImplementedError() @@ -282,12 +303,11 @@ async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float: return model_generation logprobs = model_generation.generation_info["logprobs"]["top_logprobs"] - keys = None if self.score_type == "logprob_average_per_key": assert isinstance( model_generation.label, dict ), "logprob_average_per_key requires a dict label from attribute extraction" - keys = model_generation.label.keys() + assert keys is not None, "Keys must be provided for logprob_average_per_key" confidence = self.SUPPORTED_CALCULATORS[self.score_type]( model_generation=model_generation, diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index 7c539146..91b4413f 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -244,9 +244,7 @@ async def arun( console=self.console, ) if self.console_output - else tqdm(indices) - if self.use_tqdm - else indices + else tqdm(indices) if self.use_tqdm else indices ): chunk = dataset.inputs[current_index] examples = [] @@ -323,9 +321,20 @@ async def arun( if self.config.confidence(): try: + keys = ( + { + attribute_dict.get( + "name", "" + ): attribute_dict.get("task_type", "") + for attribute_dict in self.config.attributes() + } + if self.config.task_type() + == TaskType.ATTRIBUTE_EXTRACTION + else None + ) annotation.confidence_score = ( await self.confidence.calculate( - model_generation=annotation + model_generation=annotation, keys=keys ) ) if ( diff --git a/src/autolabel/schema.py b/src/autolabel/schema.py index 98adf2c5..48e70b3e 100644 --- a/src/autolabel/schema.py +++ b/src/autolabel/schema.py @@ -117,7 +117,7 @@ 
class LLMAnnotation(BaseModel): successfully_labeled: bool label: Any curr_sample: Optional[bytes] = "" - confidence_score: Optional[float] = None + confidence_score: Optional[Union[float, Dict[str, float]]] = None generation_info: Optional[Dict[str, Any]] = None raw_response: Optional[str] = "" explanation: Optional[str] = "" diff --git a/src/autolabel/tasks/attribute_extraction.py b/src/autolabel/tasks/attribute_extraction.py index ea4f76b7..c6758bfc 100644 --- a/src/autolabel/tasks/attribute_extraction.py +++ b/src/autolabel/tasks/attribute_extraction.py @@ -71,6 +71,11 @@ def _construct_attribute_json(self) -> str: attribute_name = attribute_dict["name"] attribute_desc = attribute_dict["description"] + if TaskType.MULTILABEL_CLASSIFICATION == attribute_dict.get( + "task_type", "" + ): + attribute_desc += " Output should be a list of labels from the options provided below, separated by semicolons." + if "options" in attribute_dict: attribute_options = attribute_dict["options"] attribute_desc += f"\nOptions:\n{','.join(attribute_options)}" @@ -268,14 +273,17 @@ def parse_llm_response( if successfully_labeled: for attribute in self.config.attributes(): - attr_options = attribute.get("options") + attr_options, attr_type = attribute.get("options"), attribute.get( + "task_type" + ) if attr_options is not None and len(attr_options) > 0: attr_label = str(llm_label.get(attribute["name"])) - if attr_label is not None and attr_label not in attr_options: - logger.warning( - f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list" - ) - llm_label.pop(attribute["name"], None) + if attr_type == TaskType.CLASSIFICATION: + if attr_label is not None and attr_label not in attr_options: + logger.warning( + f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list" + ) + llm_label.pop(attribute["name"], None) return LLMAnnotation( curr_sample=pickle.dumps(curr_sample), From ff60a521238329dc119ffee3744366d6a982720f Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Wed, 18 Sep 2024 15:04:45 -0700 Subject: [PATCH 2/5] Confidence for multilabel [to be deprecated] --- src/autolabel/confidence.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index 6fbf62a0..11aa4b6d 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -132,7 +132,9 @@ def logprob_average_per_label( if logprobs is None or len(logprobs) == 0: return {} return self._logprob_average_per_label( - logprobs, delimiter, label=model_generation.label + logprobs=logprobs, + label=model_generation.label, + delimiter=delimiter, ) def logprob_average_per_key( @@ -179,10 +181,6 @@ def logprob_average_per_key( end_token = mapping[loc + len(key_to_find) - 1] locations.append((start_token, end_token, key)) locations.sort() - - logger.error(f"Keys: {keys}") - logger.error(f"Locations: {locations}") - logger.error(f"Logprobs: {logprobs}") # Here, the locations consist of the start and end *token* indices for each key # i.e for the keys A and B, we find the start and end tokens where they are found in the logprobs list # and store them in locations. 
For eg - locations can be [(1, 3, "A"), (9, 12, "C")] From 0167c705acd4beba2a712e9a104bc9fd206005dc Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Wed, 18 Sep 2024 15:06:52 -0700 Subject: [PATCH 3/5] fmt --- src/autolabel/labeler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index 91b4413f..81ba1462 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -244,7 +244,9 @@ async def arun( console=self.console, ) if self.console_output - else tqdm(indices) if self.use_tqdm else indices + else tqdm(indices) + if self.use_tqdm + else indices ): chunk = dataset.inputs[current_index] examples = [] From df7f0bb0fe5cac0991cea338f56cff83c105748b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 23 Sep 2024 19:05:10 -0700 Subject: [PATCH 4/5] Addressing overlapping labels --- src/autolabel/confidence.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index 11aa4b6d..e68b5e35 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -107,12 +107,28 @@ def _logprob_average_per_label( for conf_label_candiate in conf_label_keys: closest_match, closest_match_score = None, 0 for label in logprob_per_label: + + # The SequenceMatcher class is used to compare two sequences. It is especially useful for comparing sequences of characters. + # None - This is a function that is used to compare the two sequences. If it is None, the default function is used. + # label - The first sequence to compare + # conf_label_candiate - The second sequence to compare + + # The find_longest_match function returns a named tuple with the following fields: + # a - The start of the matching subsequence in the first sequence + # b - The start of the matching subsequence in the second sequence + # size - The length of the matching subsequence + longest_substring = difflib.SequenceMatcher( None, label, conf_label_candiate ).find_longest_match(0, len(label), 0, len(conf_label_candiate)) - if longest_substring.size > closest_match_score: + if ( + longest_substring.size + / (1e-6 + max(len(label), len(conf_label_candiate))) + ) > closest_match_score: closest_match = label - closest_match_score = longest_substring.size + closest_match_score = longest_substring.size / ( + 1e-6 + max(len(label), len(conf_label_candiate)) + ) if closest_match is not None: logprob_per_label[closest_match] = conf_label_keys[conf_label_candiate] @@ -208,7 +224,6 @@ def logprob_average_per_key( logprob_per_key[curr_key] = self.logprob_average( logprobs[locations[i][1] + 1 : locations[i + 1][0]] ) - logger.error(f"Logprob per key: {logprob_per_key}") return logprob_per_key async def p_true(self, model_generation: LLMAnnotation, **kwargs) -> float: From 426cd88f56c648538a4c48245cafaf001a1bac38 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 23 Sep 2024 19:08:16 -0700 Subject: [PATCH 5/5] black --- src/autolabel/confidence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index e68b5e35..44715e1c 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -107,7 +107,6 @@ def _logprob_average_per_label( for conf_label_candiate in conf_label_keys: closest_match, closest_match_score = None, 0 for label in logprob_per_label: - # The SequenceMatcher class is used to compare two sequences. It is especially useful for comparing sequences of characters. 
# None - This is a function that is used to compare the two sequences. If it is None, the default function is used. # label - The first sequence to compare
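Note on the approach in this patch series: `_logprob_average_per_label` splits the generated string on the delimiter, takes the geometric mean of the token probabilities inside each segment, and then maps each segment back to a parsed label using `difflib.SequenceMatcher`'s longest common substring, normalised by length (PATCH 4/5, "Addressing overlapping labels") so that a long candidate segment cannot outscore an exact match on a short label. The standalone sketch below illustrates the same idea outside of autolabel; the function name `per_label_confidence`, its signature, and the example logprobs are illustrative assumptions, not part of the library's API.

import difflib
import math
from typing import Dict, List


def per_label_confidence(
    token_logprobs: List[Dict[str, float]],
    label: str,
    delimiter: str = ";",
) -> Dict[str, float]:
    """Geometric-mean token probability per delimiter-separated label.

    token_logprobs: one {token: logprob} dict per generated token.
    label: the parsed model output, e.g. "Abc;Bcd;C".
    """
    # Rebuild the generated text and remember which token each character came from.
    chars, char_to_token = [], []
    for idx, entry in enumerate(token_logprobs):
        token = next(iter(entry))
        chars.extend(token)
        char_to_token.extend([idx] * len(token))
    text = "".join(chars)

    def avg(start_tok: int, end_tok: int) -> float:
        # Geometric mean of the token probabilities in [start_tok, end_tok).
        lps = [next(iter(t.values())) for t in token_logprobs[start_tok:end_tok]]
        return math.exp(sum(lps) / len(lps)) if lps else 0.0

    # Split the generated text on the delimiter and average the tokens of each segment.
    segments, start_char = {}, 0
    for i, ch in enumerate(text + delimiter):
        if ch == delimiter:
            segment = text[start_char:i].strip()
            if segment:
                segments[segment] = avg(
                    char_to_token[start_char], char_to_token[i - 1] + 1
                )
            start_char = i + 1

    # Map each segment back to a parsed label via the longest common substring,
    # normalised by length so a long segment cannot dominate a short label.
    scores = {lbl.strip(): 0.0 for lbl in label.split(delimiter)}
    for segment, confidence in segments.items():
        best, best_ratio = None, 0.0
        for lbl in scores:
            match = difflib.SequenceMatcher(None, lbl, segment).find_longest_match(
                0, len(lbl), 0, len(segment)
            )
            ratio = match.size / (1e-6 + max(len(lbl), len(segment)))
            if ratio > best_ratio:
                best, best_ratio = lbl, ratio
        if best is not None:
            scores[best] = confidence
    return scores


if __name__ == "__main__":
    # Hypothetical logprobs for the generation "Abc;Bcd;C".
    logprobs = [
        {"Abc": -0.2}, {";": -0.3}, {"B": -0.4},
        {"cd": -0.6}, {";": -0.5}, {"C": -0.4},
    ]
    print(per_label_confidence(logprobs, "Abc;Bcd;C"))
    # -> roughly {"Abc": 0.82, "Bcd": 0.61, "C": 0.67}

The length normalisation is the key design choice: with the raw longest-match size used in PATCH 1/5, a candidate such as "Bcd extra text" could claim the match for a short label like "C" purely by containing more characters, whereas dividing by the longer of the two lengths keeps exact short matches ahead.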