From 9fa783e5a1cdbaacf2ea883774ea00b169110115 Mon Sep 17 00:00:00 2001 From: Dhruva Bansal Date: Mon, 30 Sep 2024 16:47:15 -0700 Subject: [PATCH] Setting output schema to None for expln (#907) --- src/autolabel/confidence.py | 2 +- src/autolabel/labeler.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/autolabel/confidence.py b/src/autolabel/confidence.py index 62f2a130..a949d749 100644 --- a/src/autolabel/confidence.py +++ b/src/autolabel/confidence.py @@ -228,7 +228,7 @@ async def p_true(self, model_generation: LLMAnnotation, **kwargs) -> float: if self.llm.returns_token_probs(): p_true_prompt = model_generation.prompt + p_true_prompt - response = self.llm.label([p_true_prompt]) + response = self.llm.label([p_true_prompt], output_schema=None) response_logprobs = response.generations[0][0].generation_info["logprobs"][ "top_logprobs" ] diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index 51ee1551..d55e9b3b 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -225,9 +225,7 @@ async def arun( console=self.console, ) if self.console_output - else tqdm(indices) - if self.use_tqdm - else indices + else tqdm(indices) if self.use_tqdm else indices ): chunk = dataset.inputs[current_index] examples = [] @@ -562,7 +560,7 @@ async def agenerate_explanations( if col in seed_example and seed_example[col] is not None: explanation_prompt[col] = seed_example[col] explanation_prompt = json.dumps(explanation_prompt) - response = await self.llm.label([explanation_prompt]) + response = await self.llm.label([explanation_prompt], output_schema=None) explanation = response.generations[0][0].text seed_example[explanation_column] = str(explanation) if explanation else "" if return_annotations: @@ -613,7 +611,7 @@ def generate_synthetic_dataset(self) -> AutolabelDataset: ): prompt = self.task.get_generate_dataset_prompt(label) - result = self.llm.label([prompt]) + result = self.llm.label([prompt], output_schema=None) if result.errors[0] is not None: self.console.print( f"Error generating rows for label {label}: {result.errors[0]}"