
Commit

circular imports fixed
DhruvaBansal00 committed Nov 27, 2024
1 parent f8e255b commit aa3dfaa
Showing 6 changed files with 63 additions and 31 deletions.
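
Every metrics module here switches from importing BaseMetric through the package (from autolabel.metrics import BaseMetric) to importing it from the defining submodule (from .base import BaseMetric), and attribute_extraction.py does the same for BaseTask; the remaining changes are formatting-only signature and assignment reflows. Below is a minimal, self-contained sketch of the import cycle that the relative import avoids. The throwaway package name, its __init__.py contents, and the import order are assumptions made for illustration, not code from this repository.

import os
import sys
import tempfile

# Throwaway package on disk. The layout mirrors autolabel/metrics, but the
# __init__.py contents and its import order are assumed for illustration;
# they are not taken from this repository.
root = tempfile.mkdtemp()
pkg = os.path.join(root, "metrics_demo")
os.makedirs(pkg)

before = {
    # Package init eagerly re-exports the concrete metric (assumed ordering).
    "__init__.py": "from .accuracy import AccuracyMetric\nfrom .base import BaseMetric\n",
    "base.py": "class BaseMetric:\n    pass\n",
    # BEFORE this commit: the submodule imports the *package*, which is still
    # in the middle of executing __init__.py, so BaseMetric is not bound yet.
    "accuracy.py": "from metrics_demo import BaseMetric\n\n"
    "class AccuracyMetric(BaseMetric):\n    pass\n",
}
for name, src in before.items():
    with open(os.path.join(pkg, name), "w") as f:
        f.write(src)

sys.path.insert(0, root)
try:
    import metrics_demo  # noqa: F401
except ImportError as exc:
    print("circular import:", exc)

# AFTER this commit: the submodule imports the defining module directly,
# so the half-initialized package __init__.py is never consulted.
with open(os.path.join(pkg, "accuracy.py"), "w") as f:
    f.write("from .base import BaseMetric\n\nclass AccuracyMetric(BaseMetric):\n    pass\n")

for name in [m for m in sys.modules if m.startswith("metrics_demo")]:
    del sys.modules[name]  # drop any partial state before re-importing

import metrics_demo

print("imports cleanly:", metrics_demo.AccuracyMetric)

The relative form works because importing the submodule directly never re-enters the package's __init__.py, so nothing depends on how far that file has executed when the submodule loads.
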
7 changes: 5 additions & 2 deletions src/autolabel/metrics/accuracy.py
@@ -3,9 +3,10 @@
 
 from sklearn.metrics import accuracy_score
 
-from autolabel.metrics import BaseMetric
 from autolabel.schema import LLMAnnotation, MetricResult, MetricType
 
+from .base import BaseMetric
+
 logger = logging.getLogger(__name__)
 
 
@@ -14,7 +15,9 @@ def __init__(self) -> None:
         super().__init__()
 
     def compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         # If there are not ground truth labels, return an empty list
         if not gt_labels:
7 changes: 5 additions & 2 deletions src/autolabel/metrics/auroc.py
@@ -5,9 +5,10 @@
 import pylcs
 from sklearn.metrics import roc_auc_score
 
-from autolabel.metrics import BaseMetric
 from autolabel.schema import LLMAnnotation, MetricResult, MetricType
 
+from .base import BaseMetric
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,7 +24,9 @@ def similarity_acceptance(self, a, b) -> bool:
         return substring_lengths / max(len(a) + 1e-5, len(b) + 1e-5)
 
     def compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         if not gt_labels:
             logger.warning(
7 changes: 5 additions & 2 deletions src/autolabel/metrics/classification_report.py
@@ -3,9 +3,10 @@
 
 from sklearn.metrics import classification_report
 
-from autolabel.metrics import BaseMetric
 from autolabel.schema import LLMAnnotation, MetricResult, MetricType
 
+from .base import BaseMetric
+
 logger = logging.getLogger(__name__)
 
 
@@ -14,7 +15,9 @@ def __init__(self) -> None:
         super().__init__()
 
     def compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         # If there are not ground truth labels, return an empty list
         if not gt_labels:
7 changes: 5 additions & 2 deletions src/autolabel/metrics/completion_rate.py
@@ -1,15 +1,18 @@
 from typing import List
 
-from autolabel.metrics import BaseMetric
 from autolabel.schema import LLMAnnotation, MetricResult, MetricType
 
+from .base import BaseMetric
+
 
 class CompletionRateMetric(BaseMetric):
     def __init__(self) -> None:
         super().__init__()
 
     def compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         completed = 0
         for label in llm_labels:
15 changes: 11 additions & 4 deletions src/autolabel/metrics/f1.py
@@ -4,10 +4,11 @@
 from sklearn.metrics import f1_score
 from sklearn.preprocessing import MultiLabelBinarizer
 
-from autolabel.metrics import BaseMetric
 from autolabel.schema import F1Type, LLMAnnotation, MetricResult, MetricType
 from autolabel.utils import normalize_text
 
+from .base import BaseMetric
+
 logger = logging.getLogger(__name__)
 
 
@@ -26,7 +27,9 @@ def __init__(
         self.average = average
 
     def multi_label_compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         filtered_llm_labels = []
         filtered_gt_labels = []
@@ -52,7 +55,9 @@ def multi_label_compute(
         return value
 
     def text_compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         truth = [normalize_text(gt_label).split(self.sep) for gt_label in gt_labels]
         prediction = [
@@ -101,7 +106,9 @@ def text_compute(
         return values
 
     def compute(
-        self, llm_labels: List[LLMAnnotation], gt_labels: List[str],
+        self,
+        llm_labels: List[LLMAnnotation],
+        gt_labels: List[str],
     ) -> List[MetricResult]:
         # If there are not ground truth labels, return an empty list
         if not gt_labels:
51 changes: 32 additions & 19 deletions src/autolabel/tasks/attribute_extraction.py
@@ -24,9 +24,10 @@
     MetricResult,
     TaskType,
 )
-from autolabel.tasks import BaseTask
 from autolabel.utils import get_format_variables
 
+from .base import BaseTask
+
 logger = logging.getLogger(__name__)
 
 
@@ -54,7 +55,8 @@ def __init__(self, config: AutolabelConfig) -> None:
             self.metrics.append(AUROCMetric())
 
     def _construct_attribute_json(
-        self, selected_labels_map: Dict[str, List[str]] = None,
+        self,
+        selected_labels_map: Dict[str, List[str]] = None,
     ) -> Tuple[str, Dict]:
         """
         This function is used to construct the attribute json string for the output guidelines.
@@ -66,15 +68,18 @@ def _construct_attribute_json(
             str: A string containing the output attributes.
         """
-        output_json, output_schema = {}, {
-            "title": "AnswerFormat",
-            "description": "Answer to the provided prompt.",
-            "type": "object",
-            "properties": {},
-            "required": [],
-            "additionalProperties": False,
-            "definitions": {},
-        }
+        output_json, output_schema = (
+            {},
+            {
+                "title": "AnswerFormat",
+                "description": "Answer to the provided prompt.",
+                "type": "object",
+                "properties": {},
+                "required": [],
+                "additionalProperties": False,
+                "definitions": {},
+            },
+        )
         for attribute_dict in self.config.attributes():
             if "name" not in attribute_dict or "description" not in attribute_dict:
                 raise ValueError(
@@ -84,9 +89,13 @@ def _construct_attribute_json(
             attribute_desc = attribute_dict["description"]
             attribute_name = attribute_dict["name"]
 
-            if attribute_dict.get(
-                "task_type", "",
-            ) == TaskType.MULTILABEL_CLASSIFICATION:
+            if (
+                attribute_dict.get(
+                    "task_type",
+                    "",
+                )
+                == TaskType.MULTILABEL_CLASSIFICATION
+            ):
                 attribute_desc += " The output format should be all the labels separated by semicolons. For example: label1;label2;label3"
 
             if len(attribute_dict.get("options", [])) > 0 or (
@@ -239,7 +248,8 @@ def construct_prompt(
         for key in input:
             if input[key] is not None:
                 example_template = example_template.replace(
-                    f"{{{key}}}", input[key],
+                    f"{{{key}}}",
+                    input[key],
                 )
         current_example = example_template
 
@@ -299,7 +309,10 @@ def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
         )
 
     def get_generate_dataset_prompt(
-        self, label: str, num_rows: int, guidelines: str = None,
+        self,
+        label: str,
+        num_rows: int,
+        guidelines: str = None,
     ) -> str:
         raise NotImplementedError("Dataset generation not implemented for this task")
 
@@ -372,9 +385,9 @@ def parse_llm_response(
                            original_attr_labels,
                        ),
                    )
-                    llm_label[
-                        attribute["name"]
-                    ] = self.config.label_separator().join(filtered_attr_labels)
+                    llm_label[attribute["name"]] = (
+                        self.config.label_separator().join(filtered_attr_labels)
+                    )
                     if len(filtered_attr_labels) != len(original_attr_labels):
                         logger.warning(
                             f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list. Filtered list: {filtered_attr_labels}",
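
A quick smoke test after a change like this is to import each touched leaf module directly. The sketch below is an assumption-level check: module names are taken from the file paths above, and an importable autolabel installation is assumed.

import importlib

# Module names taken from the file paths in this commit; assumes the autolabel
# package is importable in the current environment.
changed_modules = [
    "autolabel.metrics.accuracy",
    "autolabel.metrics.auroc",
    "autolabel.metrics.classification_report",
    "autolabel.metrics.completion_rate",
    "autolabel.metrics.f1",
    "autolabel.tasks.attribute_extraction",
]

for name in changed_modules:
    # A lingering cycle would surface here as ImportError
    # ("partially initialized module ...").
    importlib.import_module(name)
    print(f"ok: {name}")
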
