From 13b38c4163de9314f53c6abd73b0aea0c6ff36a2 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Fri, 22 Nov 2024 14:14:33 -0800 Subject: [PATCH 1/5] fmt --- src/autolabel/models/openai_vision.py | 17 +++++----- src/autolabel/transforms/serp_api.py | 46 ++++++++++++++++----------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/autolabel/models/openai_vision.py b/src/autolabel/models/openai_vision.py index 362285ef..4c3d2727 100644 --- a/src/autolabel/models/openai_vision.py +++ b/src/autolabel/models/openai_vision.py @@ -51,8 +51,7 @@ class OpenAIVisionLLM(BaseModel): def _engine(self) -> str: if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS: return "chat" - else: - return "completion" + return "completion" def __init__( self, @@ -66,7 +65,7 @@ def __init__( from openai import OpenAI except ImportError: raise ImportError( - "openai is required to use the OpenAIVisionLLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'" + "openai is required to use the OpenAIVisionLLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'", ) # populate model name @@ -106,7 +105,7 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: "url": parsed_prompt[col], "detail": "high", }, - } + }, ) result = self.llm( messages=[ @@ -114,15 +113,15 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: "role": "user", "content": content, }, - ] + ], ) generations.append( [ Generation( text=result.choices[0].message.content, generation_info=None, - ) - ] + ), + ], ) except Exception as e: logger.error(f"Error generating label: {e}") @@ -131,8 +130,8 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: Generation( text="", generation_info=None, - ) - ] + ), + ], ) return RefuelLLMResult( generations=generations, diff --git a/src/autolabel/transforms/serp_api.py b/src/autolabel/transforms/serp_api.py index f3323806..cce39780 100644 --- a/src/autolabel/transforms/serp_api.py +++ b/src/autolabel/transforms/serp_api.py @@ -1,12 +1,14 @@ -from collections import defaultdict import json -from autolabel.cache import BaseCache -from autolabel.transforms import BaseTransform -from langchain_community.utilities import SerpAPIWrapper -from typing import Dict, Any, List import logging +import time +from collections import defaultdict +from typing import Any, Dict, List + import pandas as pd +from langchain_community.utilities import SerpAPIWrapper +from autolabel.cache import BaseCache +from autolabel.transforms import BaseTransform from autolabel.transforms.schema import ( TransformError, TransformErrorType, @@ -21,7 +23,9 @@ class RefuelSerpAPIWrapper(SerpAPIWrapper): def __init__(self, search_engine=None, params=None, serpapi_api_key=None): super().__init__( - search_engine=search_engine, params=params, serpapi_api_key=serpapi_api_key + search_engine=search_engine, + params=params, + serpapi_api_key=serpapi_api_key, ) async def arun(self, query: str, **kwargs: Any) -> Dict: @@ -33,21 +37,21 @@ def _process_response(self, res: Dict) -> Dict: Processes the response from Serp API and returns the search results. 
""" cleaned_res = {} - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SerpAPI: {res['error']}") - if "knowledge_graph" in res.keys(): + if "knowledge_graph" in res: cleaned_res["knowledge_graph"] = json.dumps(res["knowledge_graph"]) - if "organic_results" in res.keys(): + if "organic_results" in res: organic_results = list( map( lambda result: dict( filter( lambda item: item[0] in self.DEFAULT_ORGANIC_RESULTS_KEYS, result.items(), - ) + ), ), res["organic_results"], - ) + ), ) cleaned_res["organic_results"] = json.dumps(organic_results) return cleaned_res @@ -78,7 +82,9 @@ def __init__( self.serp_api_key = serp_api_key self.serp_args = serp_args self.serp_api_wrapper = RefuelSerpAPIWrapper( - search_engine=None, params=self.serp_args, serpapi_api_key=self.serp_api_key + search_engine=None, + params=self.serp_args, + serpapi_api_key=self.serp_api_key, ) def name(self) -> str: @@ -92,7 +98,7 @@ async def _get_result(self, query): try: search_result = await self.serp_api_wrapper.arun(query=query) except Exception as e: - logger.error(f"Error while making request to Serp API: {str(e)}") + logger.error(f"Error while making request to Serp API: {e!s}") raise TransformError( TransformErrorType.TRANSFORM_API_ERROR, f"Error while making request with query: {query}", @@ -100,13 +106,14 @@ async def _get_result(self, query): return search_result async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]: + start_time = time.time() for col in self.query_columns: if col not in row: logger.warning( f"Missing query column: {col} in row {row}", ) query = self.query_template.format_map( - defaultdict(str, {key: val for key, val in row.items() if val is not None}) + defaultdict(str, {key: val for key, val in row.items() if val is not None}), ) search_result = self.NULL_TRANSFORM_TOKEN if pd.isna(query) or query == self.NULL_TRANSFORM_TOKEN: @@ -114,16 +121,19 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]: TransformErrorType.INVALID_INPUT, f"Empty query in row {row}", ) - else: - search_result = await self._get_result(query) + search_result = await self._get_result(query) transformed_row = { self.output_columns["knowledge_graph_results"]: search_result.get( - "knowledge_graph" + "knowledge_graph", ), self.output_columns["organic_search_results"]: search_result.get( - "organic_results" + "organic_results", ), } + end_time = time.time() + logger.error( + f"Time taken to run Serp API: {end_time - start_time} seconds", + ) return self._return_output_row(transformed_row) From 5cbb73217777fee2af3273780ab0194c25f1243c Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 25 Nov 2024 18:30:10 -0800 Subject: [PATCH 2/5] Supporting sending s3 uris for images and pdfs to autolabel --- src/autolabel/dataset/dataset.py | 39 ++++--- src/autolabel/labeler.py | 83 +++++++++------ src/autolabel/models/openai.py | 36 ++++--- src/autolabel/models/openai_vision.py | 4 +- src/autolabel/task_chain/task_chain.py | 96 +++++++++++++---- src/autolabel/tasks/attribute_extraction.py | 111 ++++++++++++-------- src/autolabel/tasks/base.py | 11 +- src/autolabel/transforms/ocr.py | 2 +- src/autolabel/utils.py | 76 +++++++++++--- 9 files changed, 314 insertions(+), 144 deletions(-) diff --git a/src/autolabel/dataset/dataset.py b/src/autolabel/dataset/dataset.py index e3ceb612..7684f07f 100644 --- a/src/autolabel/dataset/dataset.py +++ b/src/autolabel/dataset/dataset.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Dict, List, Union, Optional +from typing import 
Callable, Dict, List, Optional, Union import pandas as pd from rich.console import Console @@ -39,12 +39,14 @@ def __init__( ) -> None: """ Initializes the dataset. + Args: dataset: The dataset to be used for labeling. Could be a path to a csv/jsonl file or a pandas dataframe. config: The config to be used for labeling. Could be a path to a json file or a dictionary. max_items: The maximum number of items to be parsed into the dataset object. start_index: The index to start parsing the dataset from. validate: Whether to validate the dataset or not. + """ if not (isinstance(config, AutolabelConfig)): self.config = AutolabelConfig(config) @@ -105,7 +107,9 @@ def get_slice(self, max_items: int = None, start_index: int = 0): return AutolabelDataset(df, self.config) def process_labels( - self, llm_labels: List[LLMAnnotation], metrics: List[MetricResult] = None + self, + llm_labels: List[LLMAnnotation], + metrics: List[MetricResult] = None, ): # Add the LLM labels to the dataframe self.df[self.generate_label_name("label")] = [x.label for x in llm_labels] @@ -152,13 +156,13 @@ def process_labels( for x in llm_labels: if x.successfully_labeled: attr_confidence_scores.append( - x.confidence_score.get(attr["name"], 0.0) + x.confidence_score.get(attr["name"], 0.0), ) else: attr_confidence_scores.append(0.0) - self.df[ - self.generate_label_name("confidence", attr["name"]) - ] = attr_confidence_scores + self.df[self.generate_label_name("confidence", attr["name"])] = ( + attr_confidence_scores + ) # Add the LLM explanations to the dataframe if chain of thought is set in config if self.config.chain_of_thought(): @@ -169,8 +173,10 @@ def process_labels( def save(self, output_file_name: str): """ Saves the dataset to a file based on the file extension. + Args: output_file_name: The name of the file to save the dataset to. Based on the extension we can save to a csv or jsonl file. + """ if output_file_name.endswith(".csv"): self.df.to_csv( @@ -245,21 +251,26 @@ def completed(self): return AutolabelDataset(filtered_df, self.config) def incorrect( - self, label: str = None, ground_truth: str = None, label_column: str = None + self, + label: str = None, + ground_truth: str = None, + label_column: str = None, ): """ Filter the dataset to only include incorrect items. This means the labels where the llm label was incorrect. + Args: label: The llm label to filter on. ground_truth: The ground truth label to filter on. label_column: The column to filter on. This is only used for attribute extraction tasks. + """ gt_label_column = label_column or self.config.label_column() if gt_label_column is None: raise ValueError( - "Cannot compute mistakes without ground truth label column" + "Cannot compute mistakes without ground truth label column", ) filtered_df = self.df[ @@ -281,8 +292,10 @@ def correct(self, label_column: str = None): """ Filter the dataset to only include correct items. This means the labels where the llm label was correct. + Args: label_column: The column to filter on. This is only used for attribute extraction tasks. + """ gt_label_column = label_column or self.config.label_column() @@ -298,12 +311,14 @@ def correct(self, label_column: str = None): def filter_by_confidence(self, threshold: float = 0.5): """ Filter the dataset to only include items with confidence scores greater than the threshold. + Args: threshold: The threshold to filter on. This means that only items with confidence scores greater than the threshold will be included. 
+ """ if not self.config.confidence(): raise ValueError( - "Cannot compute correct and confident without confidence scores" + "Cannot compute correct and confident without confidence scores", ) filtered_df = self.df[ @@ -360,13 +375,13 @@ def _validate(self): if len(self.__malformed_records) > 0: logger.warning( - f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}" + f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}", ) raise DataValidationFailed( - f"Validation failed for {len(self.__malformed_records)} rows." + f"Validation failed for {len(self.__malformed_records)} rows.", ) - def generate_label_name(self, col_name: str, label_column: str = None): + def generate_label_name(self, col_name: str, label_column: str = None) -> str: label_column = label_column or f"{self.config.task_name()}_task" return f"{label_column}_{col_name}" diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index a768cdd4..4e7bcf13 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -94,7 +94,7 @@ def __init__( self.confidence_cache = confidence_cache if not cache: logger.warning( - "cache parameter is deprecated and will be removed soon. Please use generation_cache, transform_cache and confidence_cache instead." + "cache parameter is deprecated and will be removed soon. Please use generation_cache, transform_cache and confidence_cache instead.", ) self.generation_cache = None self.transform_cache = None @@ -116,13 +116,15 @@ def __init__( ) self.task = TaskFactory.from_config(self.config) self.llm: BaseModel = ModelFactory.from_config( - self.config, cache=self.generation_cache, tokenizer=confidence_tokenizer + self.config, + cache=self.generation_cache, + tokenizer=confidence_tokenizer, ) if self.config.confidence_chunk_column(): if not confidence_tokenizer: self.confidence_tokenizer = AutoTokenizer.from_pretrained( - **DEFAULT_TOKENIZATION_MODEL + **DEFAULT_TOKENIZATION_MODEL, ) else: self.confidence_tokenizer = confidence_tokenizer @@ -159,7 +161,7 @@ def run( start_index=start_index, additional_metrics=additional_metrics, skip_eval=skip_eval, - ) + ), ) async def arun( @@ -171,15 +173,16 @@ async def arun( additional_metrics: Optional[List[BaseMetric]] = [], skip_eval: Optional[bool] = False, ) -> Tuple[pd.Series, pd.DataFrame, List[MetricResult]]: - """Labels data in a given dataset. Output written to new CSV file. + """ + Labels data in a given dataset. Output written to new CSV file. Args: dataset: path to CSV dataset to be annotated max_items: maximum items in dataset to be annotated output_name: custom name of output CSV file start_index: skips annotating [0, start_index) - """ + """ dataset = dataset.get_slice(max_items=max_items, start_index=start_index) llm_labels = [] @@ -198,7 +201,7 @@ async def arun( and self.config.explanation_column() not in list(seed_examples[0].keys()) ): raise ValueError( - f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)." 
+ f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file).", ) if self.example_selector is None and self.config.few_shot_algorithm(): @@ -208,7 +211,7 @@ async def arun( ): # TODO: Add support for other few shot algorithms specially semantic similarity raise ValueError( - "Error: Only 'fixed' few shot example selector is supported for label selection." + "Error: Only 'fixed' few shot example selector is supported for label selection.", ) self.example_selector = ExampleSelectorFactory.initialize_selector( @@ -226,7 +229,7 @@ async def arun( self.label_selector_map = {} for attribute in self.config.attributes(): label_selection_count = attribute.get( - AutolabelConfig.LABEL_SELECTION_KEY + AutolabelConfig.LABEL_SELECTION_KEY, ) if label_selection_count: label_selector = LabelSelector( @@ -275,11 +278,10 @@ async def arun( safe_serialize_to_string(chunk), selected_labels_map=selected_labels_map, ) - else: - if self.example_selector: - examples = self.example_selector.select_examples( - safe_serialize_to_string(chunk), - ) + elif self.example_selector: + examples = self.example_selector.select_examples( + safe_serialize_to_string(chunk), + ) # Construct Prompt to pass to LLM final_prompt, output_schema = self.task.construct_prompt( @@ -325,7 +327,7 @@ async def arun( ) annotation.input_tokens = input_tokens annotation.output_tokens = self.llm.get_num_tokens( - annotation.raw_response + annotation.raw_response, ) annotation.cost = sum(response.costs) annotation.latency = latency @@ -335,7 +337,8 @@ async def arun( keys = ( { attribute_dict.get( - "name", "" + "name", + "", ): attribute_dict.get("task_type", "") for attribute_dict in self.config.attributes() } @@ -345,15 +348,16 @@ async def arun( ) annotation.confidence_score = ( await self.confidence.calculate( - model_generation=annotation, keys=keys + model_generation=annotation, + keys=keys, ) ) except Exception as e: logger.exception( - f"Error calculating confidence score: {e}" + f"Error calculating confidence score: {e}", ) logger.warning( - f"Could not calculate confidence score for annotation: {annotation}" + f"Could not calculate confidence score for annotation: {annotation}", ) annotation.confidence_score = {} annotations.append(annotation) @@ -384,7 +388,7 @@ async def arun( # This is a row wise metric if isinstance(m.value, list): continue - elif m.show_running: + if m.show_running: postfix_dict[m.name] = ( f"{m.value:.4f}" if isinstance(m.value, float) @@ -409,7 +413,7 @@ async def arun( for m in eval_result: if isinstance(m.value, list): continue - elif m.show_running: + if m.show_running: table[m.name] = m.value else: self.console.print(f"{m.name}:\n{m.value}") @@ -437,10 +441,12 @@ def plan( max_items: Optional[int] = None, start_index: int = 0, ) -> None: - """Calculates and prints the cost of calling autolabel.run() on a given dataset + """ + Calculates and prints the cost of calling autolabel.run() on a given dataset Args: dataset: path to a CSV dataset + """ dataset = dataset.get_slice(max_items=max_items, start_index=start_index) @@ -450,7 +456,7 @@ def plan( and not self.llm.returns_token_probs() ): raise ValueError( - "REFUEL_API_KEY environment variable must be set to compute confidence scores. You can request an API key at https://refuel-ai.typeform.com/llm-access." + "REFUEL_API_KEY environment variable must be set to compute confidence scores. 
You can request an API key at https://refuel-ai.typeform.com/llm-access.", ) prompt_list = [] @@ -471,7 +477,7 @@ def plan( and self.config.explanation_column() not in list(seed_examples[0].keys()) ): raise ValueError( - f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)." + f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file).", ) self.example_selector = ExampleSelectorFactory.initialize_selector( @@ -490,7 +496,7 @@ def plan( # TODO: Check if this needs to use the example selector if self.example_selector: examples = self.example_selector.select_examples( - safe_serialize_to_string(input_i) + safe_serialize_to_string(input_i), ) else: examples = [] @@ -516,14 +522,19 @@ def plan( table = {"parameter": list(table.keys()), "value": list(table.values())} print_table( - table, show_header=False, console=self.console, styles=COST_TABLE_STYLES + table, + show_header=False, + console=self.console, + styles=COST_TABLE_STYLES, ) self.console.rule("Prompt Example") self.console.print(f"{prompt_list[0]}", markup=False) self.console.rule() async def async_run_transform( - self, transform: BaseTransform, dataset: AutolabelDataset + self, + transform: BaseTransform, + dataset: AutolabelDataset, ) -> AutolabelDataset: transform_outputs = [ transform.apply(input_dict) for input_dict in dataset.inputs @@ -547,7 +558,7 @@ def transform(self, dataset: AutolabelDataset) -> AutolabelDataset: transforms = [] for transform_dict in self.config.transforms(): transforms.append( - TransformFactory.from_dict(transform_dict, cache=self.transform_cache) + TransformFactory.from_dict(transform_dict, cache=self.transform_cache), ) for transform in transforms: dataset = asyncio.run(self.async_run_transform(transform, dataset)) @@ -555,7 +566,8 @@ def transform(self, dataset: AutolabelDataset) -> AutolabelDataset: return dataset def majority_annotation( - self, annotation_list: List[LLMAnnotation] + self, + annotation_list: List[LLMAnnotation], ) -> LLMAnnotation: labels = [a.label for a in annotation_list] counts = {} @@ -581,7 +593,7 @@ def generate_explanations( seed_examples=seed_examples, include_label=include_label, return_annotations=return_annotations, - ) + ), ) async def agenerate_explanations( @@ -600,7 +612,7 @@ async def agenerate_explanations( explanation_column = self.config.explanation_column() if not explanation_column: raise ValueError( - "The explanation column needs to be specified in the dataset config." 
+ "The explanation column needs to be specified in the dataset config.", ) llm_annotations = [] for seed_example in track( @@ -609,7 +621,8 @@ async def agenerate_explanations( console=self.console, ): explanation_prompt = self.task.get_explanation_prompt( - seed_example, include_label=include_label + seed_example, + include_label=include_label, ) if self.task.image_cols is not None and len(self.task.image_cols) > 0: explanation_prompt = {"text": explanation_prompt} @@ -671,7 +684,7 @@ def generate_synthetic_dataset(self) -> AutolabelDataset: result = self.llm.label([prompt], output_schema=None) if result.errors[0] is not None: self.console.print( - f"Error generating rows for label {label}: {result.errors[0]}" + f"Error generating rows for label {label}: {result.errors[0]}", ) else: response = result.generations[0][0].text.strip() @@ -685,8 +698,10 @@ def generate_synthetic_dataset(self) -> AutolabelDataset: def clear_cache(self, use_ttl: bool = True): """ Clears the generation and transformation cache from autolabel. + Args: use_ttl: If true, only clears the cache if the ttl has expired. + """ self.generation_cache.clear(use_ttl=use_ttl) self.transform_cache.clear(use_ttl=use_ttl) @@ -695,7 +710,7 @@ def get_num_tokens(self, inp: str) -> int: if not self.confidence_tokenizer: logger.warning("Confidence tokenizer is not set. Using default tokenizer.") self.confidence_tokenizer = AutoTokenizer.from_pretrained( - **DEFAULT_TOKENIZATION_MODEL + **DEFAULT_TOKENIZATION_MODEL, ) """Returns the number of tokens in the prompt""" return len(self.confidence_tokenizer.encode(str(inp))) diff --git a/src/autolabel/models/openai.py b/src/autolabel/models/openai.py index 4df139a1..d740cd93 100644 --- a/src/autolabel/models/openai.py +++ b/src/autolabel/models/openai.py @@ -35,7 +35,7 @@ class OpenAILLM(BaseModel): "gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", - ] + ], ) MODELS_WITH_TOKEN_PROBS = set( [ @@ -57,7 +57,7 @@ class OpenAILLM(BaseModel): "gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", - ] + ], ) SUPPORTS_JSON_OUTPUTS = set( @@ -70,7 +70,7 @@ class OpenAILLM(BaseModel): "gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", - ] + ], ) SUPPORTS_STRUCTURED_OUTPUTS = set( @@ -78,7 +78,7 @@ class OpenAILLM(BaseModel): "gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", - ] + ], ) # Default parameters for OpenAILLM @@ -146,8 +146,7 @@ class OpenAILLM(BaseModel): def _engine(self) -> str: if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS: return "chat" - else: - return "completion" + return "completion" def __init__( self, @@ -161,7 +160,7 @@ def __init__( from langchain_openai import ChatOpenAI, OpenAI except ImportError: raise ImportError( - "openai is required to use the OpenAILLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'" + "openai is required to use the OpenAILLM. 
Please install it with the following command: pip install 'refuel-autolabel[openai]'", ) self.tiktoken = tiktoken # populate model name @@ -175,7 +174,9 @@ def __init__( if self._engine == "chat": self.model_params = {**self.DEFAULT_PARAMS_CHAT_ENGINE, **model_params} self.llm = ChatOpenAI( - model_name=self.model_name, verbose=False, **self.model_params + model_name=self.model_name, + verbose=False, + **self.model_params, ) else: self.model_params = { @@ -183,11 +184,14 @@ def __init__( **model_params, } self.llm = OpenAI( - model_name=self.model_name, verbose=False, **self.model_params + model_name=self.model_name, + verbose=False, + **self.model_params, ) def _chat_backward_compatibility( - self, generations: List[LLMResult] + self, + generations: List[LLMResult], ) -> List[LLMResult]: for generation_options in generations: for curr_generation in generation_options: @@ -195,7 +199,7 @@ def _chat_backward_compatibility( new_logprobs = {"top_logprobs": []} for curr_token in generation_info["logprobs"]["content"]: new_logprobs["top_logprobs"].append( - {curr_token["token"]: curr_token["logprob"]} + {curr_token["token"]: curr_token["logprob"]}, ) curr_generation.generation_info["logprobs"] = new_logprobs return generations @@ -231,7 +235,7 @@ async def _alabel(self, prompts: List[str], output_schema: Dict) -> RefuelLLMRes ) else: logger.info( - "Not using structured output despite output_schema provided" + "Not using structured output despite output_schema provided", ) result = await self.llm.agenerate(prompts) generations = self._chat_backward_compatibility(result.generations) @@ -255,7 +259,8 @@ async def _alabel(self, prompts: List[str], output_schema: Dict) -> RefuelLLMRes ] error_code = error_json.get("code") error_type = self.ERROR_TYPE_MAPPING.get( - error_code, ErrorType.LLM_PROVIDER_ERROR + error_code, + ErrorType.LLM_PROVIDER_ERROR, ) error_message = error_json.get("message") except Exception as e: @@ -303,7 +308,7 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: ) else: logger.info( - "Not using structured output despite output_schema provided" + "Not using structured output despite output_schema provided", ) result = self.llm.generate(prompts) generations = self._chat_backward_compatibility(result.generations) @@ -328,7 +333,8 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: ] error_code = error_json.get("code") error_type = self.ERROR_TYPE_MAPPING.get( - error_code, ErrorType.LLM_PROVIDER_ERROR + error_code, + ErrorType.LLM_PROVIDER_ERROR, ) error_message = error_json.get("message") except Exception as e: diff --git a/src/autolabel/models/openai_vision.py b/src/autolabel/models/openai_vision.py index 4c3d2727..f24b00d4 100644 --- a/src/autolabel/models/openai_vision.py +++ b/src/autolabel/models/openai_vision.py @@ -84,6 +84,7 @@ def __init__( ) self.tiktoken = tiktoken self.image_cols = config.image_columns() + self.input_cols = config.input_columns() def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: generations = [] @@ -95,7 +96,8 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult: if self.image_cols: for col in self.image_cols: if ( - parsed_prompt.get(col) is not None + col in self.input_cols + and parsed_prompt.get(col) is not None and len(parsed_prompt[col]) > 0 ): content.append( diff --git a/src/autolabel/task_chain/task_chain.py b/src/autolabel/task_chain/task_chain.py index 5d4bb156..40355f96 100644 --- a/src/autolabel/task_chain/task_chain.py +++ 
b/src/autolabel/task_chain/task_chain.py @@ -1,22 +1,25 @@ -from collections import defaultdict -from autolabel.configs import AutolabelConfig +import copy import logging -from typing import Dict, List, Optional +from collections import defaultdict +from typing import Dict, List, Optional, Tuple -from autolabel.few_shot.base_label_selector import BaseLabelSelector -from autolabel.labeler import LabelingAgent +import boto3 +import pandas as pd +from transformers import AutoTokenizer + +from autolabel.cache.base import BaseCache +from autolabel.cache.sqlalchemy_confidence_cache import SQLAlchemyConfidenceCache +from autolabel.cache.sqlalchemy_generation_cache import SQLAlchemyGenerationCache +from autolabel.cache.sqlalchemy_transform_cache import SQLAlchemyTransformCache +from autolabel.configs import AutolabelConfig, TaskChainConfig from autolabel.dataset import AutolabelDataset from autolabel.few_shot import ( BaseExampleSelector, ) -from autolabel.cache.sqlalchemy_generation_cache import SQLAlchemyGenerationCache -from autolabel.cache.sqlalchemy_transform_cache import SQLAlchemyTransformCache -from autolabel.cache.sqlalchemy_confidence_cache import SQLAlchemyConfidenceCache -from autolabel.cache.base import BaseCache -from autolabel.configs import TaskChainConfig +from autolabel.few_shot.base_label_selector import BaseLabelSelector +from autolabel.labeler import LabelingAgent from autolabel.transforms import TransformFactory -from transformers import AutoTokenizer -import pandas as pd +from autolabel.utils import generate_presigned_url, is_s3_uri logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.WARNING) @@ -28,21 +31,25 @@ def __init__(self, task_chain: List[Dict]): self.task_chain = task_chain def add_dependency(self, pre_task: str, post_task: str): - """Add dependencies between pairs of tasks + """ + Add dependencies between pairs of tasks Args: pre_task (str): The task that must be completed before post_task post_task (str): The task that depends on pre_task + """ self.graph[pre_task].add(post_task) def topological_sort_helper(self, pre_task: str, visited: Dict, stack: List): - """Recursive helper function to perform topological sort + """ + Recursive helper function to perform topological sort Args: pre_task (str): The task we are currently visiting visited (Dict): Dict of visited tasks stack (List): Stack to store the sorted tasks (in reverse order) + """ visited[pre_task] = True @@ -52,10 +59,12 @@ def topological_sort_helper(self, pre_task: str, visited: Dict, stack: List): stack.append(pre_task) def topological_sort(self) -> List[str]: - """Topological sort of the task graph + """ + Topological sort of the task graph Returns: List[str]: List of task names in topological order + """ visited = defaultdict(bool) stack = [] @@ -66,10 +75,13 @@ def topological_sort(self) -> List[str]: return stack[::-1] def check_cycle(self): - """Check for cycles in the task graph + """ + Check for cycles in the task graph Returns: - bool: True if cycle is present, False otherwise""" + bool: True if cycle is present, False otherwise + + """ visited = defaultdict(bool) rec_stack = defaultdict(bool) @@ -81,7 +93,8 @@ def check_cycle(self): return False def check_cycle_helper(self, pre_task: str, visited: Dict, rec_stack: Dict): - """Recursive helper function to check for cycles + """ + Recursive helper function to check for cycles Args: pre_task (str): The task we are currently visiting visited (Dict): List of visited tasks @@ -126,6 +139,7 @@ def __init__( 
self.confidence_endpoint = confidence_endpoint self.column_name_map = column_name_map self.label_selector_map = label_selector_map + self.s3_client = boto3.client("s3") # TODO: For now, we run each separate step of the task chain serially and aggregate at the end. # We can optimize this with parallelization where possible/no dependencies. @@ -137,6 +151,7 @@ async def run(self, dataset_df: pd.DataFrame): dataset_df (pd.DataFrame): Input dataset Returns: AutolabelDataset: Output dataset with the results of the task chain + """ subtasks = self.task_chain_config.subtasks() if len(subtasks) == 0: @@ -144,6 +159,10 @@ async def run(self, dataset_df: pd.DataFrame): for task in subtasks: autolabel_config = AutolabelConfig(task) dataset = AutolabelDataset(dataset_df, autolabel_config) + dataset, original_inputs = self.safe_convert_uri_to_presigned_url( + dataset, + autolabel_config, + ) if autolabel_config.transforms(): agent = LabelingAgent( config=autolabel_config, @@ -180,12 +199,19 @@ async def run(self, dataset_df: pd.DataFrame): dataset, skip_eval=True, ) + dataset = self.reset_presigned_url_to_uri( + dataset, + original_inputs, + autolabel_config, + ) dataset = self.rename_output_columns(dataset, autolabel_config) dataset_df = dataset.df return dataset def rename_output_columns( - self, dataset: AutolabelDataset, autolabel_config: AutolabelConfig + self, + dataset: AutolabelDataset, + autolabel_config: AutolabelConfig, ): """ Rename the output columns of the dataset for each intermediate step in the task chain so that @@ -196,6 +222,7 @@ def rename_output_columns( task (ChainTask): The current task in the task chain Returns: AutolabelDataset: The dataset with renamed output columns + """ if autolabel_config.transforms(): dataset.df.rename(columns=self.column_name_map, inplace=True) @@ -206,3 +233,34 @@ def rename_output_columns( ].apply(lambda x: x.get(attribute) if x and type(x) is dict else None) return dataset + + def safe_convert_uri_to_presigned_url( + self, + dataset: AutolabelDataset, + autolabel_config: AutolabelConfig, + ) -> Tuple[AutolabelDataset, List[Dict]]: + original_inputs = copy.deepcopy(dataset.inputs) + for col in autolabel_config.input_columns(): + for i in range(len(dataset.inputs)): + dataset.inputs[i][col] = ( + generate_presigned_url( + self.s3_client, + dataset.inputs[i][col], + ) + if is_s3_uri(dataset.inputs[i][col]) + else dataset.inputs[i][col] + ) + dataset.df.loc[i, col] = dataset.inputs[i][col] + return dataset, original_inputs + + def reset_presigned_url_to_uri( + self, + dataset: AutolabelDataset, + original_inputs: List[Dict], + autolabel_config: AutolabelConfig, + ) -> AutolabelDataset: + for col in autolabel_config.input_columns(): + for i in range(len(dataset.inputs)): + dataset.inputs[i][col] = original_inputs[i][col] + dataset.df.loc[i, col] = dataset.inputs[i][col] + return dataset diff --git a/src/autolabel/tasks/attribute_extraction.py b/src/autolabel/tasks/attribute_extraction.py index edfda42e..36c95c76 100644 --- a/src/autolabel/tasks/attribute_extraction.py +++ b/src/autolabel/tasks/attribute_extraction.py @@ -1,11 +1,11 @@ +import copy import json -import json5 import logging import pickle -import copy from collections import defaultdict -from typing import Callable, Dict, List, Optional, Union, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union +import json5 from langchain.prompts.prompt import PromptTemplate from langchain.schema import ChatGeneration, Generation @@ -54,35 +54,46 @@ def __init__(self, config: 
AutolabelConfig) -> None: self.metrics.append(AUROCMetric()) def _construct_attribute_json( - self, selected_labels_map: Dict[str, List[str]] = None + self, + selected_labels_map: Dict[str, List[str]] = None, ) -> Tuple[str, Dict]: - """This function is used to construct the attribute json string for the output guidelines. + """ + This function is used to construct the attribute json string for the output guidelines. + Args: attributes (List[Dict]): A list of dictionaries containing the output attributes. Returns: str: A string containing the output attributes. + """ - output_json, output_schema = {}, { - "title": "AnswerFormat", - "description": "Answer to the provided prompt.", - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": False, - "definitions": {}, - } + output_json, output_schema = ( + {}, + { + "title": "AnswerFormat", + "description": "Answer to the provided prompt.", + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False, + "definitions": {}, + }, + ) for attribute_dict in self.config.attributes(): if "name" not in attribute_dict or "description" not in attribute_dict: raise ValueError( - "Attribute dictionary must contain 'name' and 'description' keys" + "Attribute dictionary must contain 'name' and 'description' keys", ) attribute_desc = attribute_dict["description"] attribute_name = attribute_dict["name"] - if TaskType.MULTILABEL_CLASSIFICATION == attribute_dict.get( - "task_type", "" + if ( + attribute_dict.get( + "task_type", + "", + ) + == TaskType.MULTILABEL_CLASSIFICATION ): attribute_desc += " The output format should be all the labels separated by semicolons. For example: label1;label2;label3" @@ -104,12 +115,12 @@ def _construct_attribute_json( ): curr_property = {"$ref": "#/definitions/" + attribute_name} output_schema["definitions"][attribute_name] = json5.loads( - attribute_dict["schema"] + attribute_dict["schema"], ) else: curr_property = {"title": attribute_dict["name"], "type": "string"} if "options" in attribute_dict and len(attribute_dict["options"]) < 500: - if TaskType.CLASSIFICATION == attribute_dict.get("task_type", ""): + if attribute_dict.get("task_type", "") == TaskType.CLASSIFICATION: curr_property = {"$ref": "#/definitions/" + attribute_name} output_schema["definitions"][attribute_name] = { "title": attribute_name, @@ -122,33 +133,37 @@ def _construct_attribute_json( return json.dumps(output_json, indent=4), output_schema def _generate_output_dict(self, input: Dict) -> Optional[str]: - """Generate the output dictionary from the input + """ + Generate the output dictionary from the input Args: input (Dict): The input dictionary Returns: Dict: The output dictionary + """ output_dict = {} for attribute in self.config.attributes(): attribute_name = attribute["name"] output_dict[attribute_name] = input.get(attribute_name, "") if not self._validate_output_dict(output_dict): - logger.warn( - f"Generated output dict: {output_dict} does not contain all the expected output attributes. Skipping example." + logger.warning( + f"Generated output dict: {output_dict} does not contain all the expected output attributes. 
Skipping example.", ) return None return json.dumps(output_dict) def _validate_output_dict(self, output_dict: Dict) -> bool: - """Validate the output dictionary + """ + Validate the output dictionary Args: output_dict (Dict): The output dictionary Returns: bool: True if the output dictionary is valid, False otherwise + """ for attribute in self.config.attributes(): attribute_name = attribute.get("name") @@ -194,7 +209,7 @@ def construct_prompt( selected_labels_map[attribute_name].append(l) attribute_json, output_schema = self._construct_attribute_json( - selected_labels_map=selected_labels_map + selected_labels_map=selected_labels_map, ) output_guidelines = ( self.output_guidelines @@ -222,17 +237,18 @@ def construct_prompt( except KeyError as e: try: current_example = example_template.format_map(defaultdict(str, input)) - logger.warn( + logger.warning( f'\n\nKey {e} in the "example_template" in the given config' f"\n\n{example_template}\n\nis not present in the datsaset columns - {input.keys()}.\n\n" f"Input - {input}\n\n" - "Continuing with the prompt as {current_example}" + "Continuing with the prompt as {current_example}", ) except AttributeError as e: - for key in input.keys(): + for key in input: if input[key] is not None: example_template = example_template.replace( - f"{{{key}}}", input[key] + f"{{{key}}}", + input[key], ) current_example = example_template @@ -264,12 +280,15 @@ def construct_prompt( if self.image_cols: prompt_dict = {"text": curr_text_prompt} for col in self.image_cols: - if input.get(col) is not None and len(input.get(col)) > 0: + if ( + col in self.input_cols + and input.get(col) is not None + and len(input.get(col)) > 0 + ): prompt_dict[col] = input[col] prompt_dict[col] = input[col] return json.dumps(prompt_dict), output_schema - else: - return curr_text_prompt, output_schema + return curr_text_prompt, output_schema def get_explanation_prompt(self, example: Dict, include_label=True) -> str: pt = PromptTemplate( @@ -293,7 +312,10 @@ def get_explanation_prompt(self, example: Dict, include_label=True) -> str: ) def get_generate_dataset_prompt( - self, label: str, num_rows: int, guidelines: str = None + self, + label: str, + num_rows: int, + guidelines: str = None, ) -> str: raise NotImplementedError("Dataset generation not implemented for this task") @@ -321,13 +343,13 @@ def parse_llm_response( successfully_labeled = True except Exception as e: logger.info( - f"Error parsing LLM response: {response.text}, Error: {e}. Now searching for valid JSON in response" + f"Error parsing LLM response: {response.text}, Error: {e}. 
Now searching for valid JSON in response", ) try: json_start, json_end = response.text.find("{"), response.text.rfind("}") llm_label = {} for k, v in json5.loads( - response.text[json_start : json_end + 1] + response.text[json_start : json_end + 1], ).items(): if isinstance(v, list) or isinstance(v, dict): llm_label[k] = v @@ -353,25 +375,25 @@ def parse_llm_response( if attr_type == TaskType.CLASSIFICATION: if attr_label is not None and attr_label not in attr_options: logger.warning( - f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list" + f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list", ) llm_label.pop(attribute["name"], None) elif attr_type == TaskType.MULTILABEL_CLASSIFICATION: original_attr_labels = attr_label.split( - self.config.label_separator() + self.config.label_separator(), ) filtered_attr_labels = list( filter( lambda x: x.strip() in attr_options, original_attr_labels, - ) + ), + ) + llm_label[attribute["name"]] = ( + self.config.label_separator().join(filtered_attr_labels) ) - llm_label[ - attribute["name"] - ] = self.config.label_separator().join(filtered_attr_labels) if len(filtered_attr_labels) != len(original_attr_labels): logger.warning( - f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list. Filtered list: {filtered_attr_labels}" + f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list. Filtered list: {filtered_attr_labels}", ) if len(filtered_attr_labels) == 0: llm_label.pop(attribute["name"], None) @@ -393,7 +415,6 @@ def eval( additional_metrics: List[BaseMetric] = [], ) -> List[MetricResult]: """Evaluate the LLM generated labels by comparing them against ground truth""" - # Convert the llm labels into a mapping from # name -> List[LLMAnnotation] llm_labels_dict = defaultdict(list) @@ -412,7 +433,7 @@ def eval( if llm_label.confidence_score else 0 ), - ) + ), ) eval_metrics = [] @@ -432,7 +453,7 @@ def eval( MetricResult( name=f"{attribute}:{m.name}", value=m.value, - ) + ), ) if m.name not in macro_metrics: macro_metrics[m.name] = [] @@ -443,7 +464,7 @@ def eval( MetricResult( name=f"Macro:{key}", value=sum(macro_metrics[key]) / len(macro_metrics[key]), - ) + ), ) return eval_metrics diff --git a/src/autolabel/tasks/base.py b/src/autolabel/tasks/base.py index 5f76f40c..9a289aeb 100644 --- a/src/autolabel/tasks/base.py +++ b/src/autolabel/tasks/base.py @@ -37,7 +37,7 @@ class BaseTask(ABC): def __init__(self, config: AutolabelConfig) -> None: self.config = config self.image_cols = self.config.image_columns() - + self.input_cols = self.config.input_columns() # Update the default prompt template with the prompt template from the config self.task_guidelines = ( self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES @@ -149,12 +149,15 @@ def eval( @abstractmethod def get_explanation_prompt(self, example: Dict, include_label=True) -> str: raise NotImplementedError( - "Explanation generation not implemented for this task" + "Explanation generation not implemented for this task", ) @abstractmethod def get_generate_dataset_prompt( - self, label: str, num_rows: int, guidelines: Optional[str] = None + self, + label: str, + num_rows: int, + guidelines: Optional[str] = None, ) -> str: raise NotImplementedError("Dataset generation not implemented for this task") @@ -172,7 +175,7 @@ def parse_llm_response( try: explanation = response.text.strip().split("\n")[0].strip() completion_text = extract_valid_json_substring( - 
response.text.strip().split("\n")[-1].strip() + response.text.strip().split("\n")[-1].strip(), ) completion_text = json.loads(completion_text)["label"] except Exception as _: diff --git a/src/autolabel/transforms/ocr.py b/src/autolabel/transforms/ocr.py index ce198ae9..a6e82925 100644 --- a/src/autolabel/transforms/ocr.py +++ b/src/autolabel/transforms/ocr.py @@ -159,7 +159,7 @@ async def _apply(self, row: dict[str, Any]) -> dict[str, Any]: ) from exc ocr_output = [] - if curr_file_path.endswith(".pdf"): + if Path(curr_file_path).suffix.lower().startswith(".pdf"): pages = self.convert_from_path(curr_file_path) ocr_output = [ self.default_ocr_processor(page, lang=self.lang) for page in pages diff --git a/src/autolabel/utils.py b/src/autolabel/utils.py index 55c89b15..00151ae7 100644 --- a/src/autolabel/utils.py +++ b/src/autolabel/utils.py @@ -8,6 +8,7 @@ import string from string import Formatter from typing import Any, Dict, Iterable, List, Optional, Sequence, Union +from urllib.parse import urlparse import regex import wget @@ -74,7 +75,7 @@ def calculate_md5(input_data: Any) -> str: if isinstance(input_data, dict): # Convert dictionary to a JSON-formatted string input_str = json.dumps(input_data, sort_keys=True, skipkeys=True).encode( - "utf-8" + "utf-8", ) elif hasattr(input_data, "read"): # Read binary data from file-like object @@ -116,7 +117,7 @@ def _autolabel_progress( MofNCompleteColumn(), TimeElapsedColumn(), TimeRemainingColumn(), - ) + ), ) return Progress( *columns, @@ -151,7 +152,8 @@ def track( console: Optional[Console] = None, disable: bool = False, ) -> Iterable[ProgressType]: - """Track progress by iterating over a sequence. + """ + Track progress by iterating over a sequence. Args: sequence (Iterable[ProgressType]): A sequence (must support "len") you wish to iterate over. @@ -161,8 +163,10 @@ def track( transient (bool, optional): Clear the progress on exit. Defaults to False. console (Console, optional): Console to write to. Default creates internal Console instance. disable (bool, optional): Disable display of progress. + Returns: Iterable[ProgressType]: An iterable of the values in the sequence. + """ progress = _autolabel_progress( description=description, @@ -194,7 +198,8 @@ async def gather_async_tasks_with_progress( console: Optional[Console] = None, disable: bool = False, ) -> Iterable: - """Gather async tasks with progress bar + """ + Gather async tasks with progress bar Args: tasks (Iterable): A sequence of async tasks you wish to gather. @@ -204,8 +209,10 @@ async def gather_async_tasks_with_progress( transient (bool, optional): Clear the progress on exit. Defaults to False. console (Console, optional): Console to write to. Default creates internal Console instance. disable (bool, optional): Disable display of progress. + Returns: Iterable: Returns an iterable of the results of the async tasks. + """ progress = _autolabel_progress( description=description, @@ -247,7 +254,8 @@ def track_with_stats( console: Optional[Console] = None, disable: bool = False, ) -> Iterable[ProgressType]: - """Track progress and displays stats by iterating over a sequence. + """ + Track progress and displays stats by iterating over a sequence. Args: sequence (Iterable[ProgressType]): A sequence (must support "len") you wish to iterate over. @@ -258,8 +266,10 @@ def track_with_stats( transient (bool, optional): Clear the progress on exit. Defaults to False. console (Console, optional): Console to write to. Default creates internal Console instance. 
disable (bool, optional): Disable display of progress. + Returns: Iterable[ProgressType]: An iterable of the values in the sequence. + """ progress = _autolabel_progress( description=description, @@ -281,7 +291,8 @@ def track_with_stats( with live: progress_task = progress.add_task(description=description, total=total) stats_task = stats_progress.add_task( - "Stats", stats=", ".join(f"{k}={v}" for k, v in stats.items()) + "Stats", + stats=", ".join(f"{k}={v}" for k, v in stats.items()), ) for value in sequence: yield value @@ -290,7 +301,8 @@ def track_with_stats( advance=min(advance, total - progress.tasks[progress_task].completed), ) stats_progress.update( - stats_task, stats=", ".join(f"{k}={v}" for k, v in stats.items()) + stats_task, + stats=", ".join(f"{k}={v}" for k, v in stats.items()), ) live.refresh() @@ -299,8 +311,7 @@ def maybe_round(value: Any) -> Any: """Round's value only if it has a round function""" if hasattr(value, "__round__"): return round(value, 4) - else: - return value + return value def print_table( @@ -310,7 +321,8 @@ def print_table( default_style: str = "bold", styles: Dict = {}, ) -> None: - """Print a table of data. + """ + Print a table of data. Args: data (Dict[str, List]): A dictionary of data to print. @@ -318,6 +330,7 @@ def print_table( console (Console, optional): Console to write to. Default creates internal Console instance. default_style (str, optional): Default style to apply to the table. Defaults to "bold". styles (Dict, optional): A dictionary of styles to apply to the table. + """ # Convert all values to strings data = { @@ -338,16 +351,18 @@ def print_table( def get_data(dataset_name: str, force: bool = False): - """Download Datasets + """ + Download Datasets Args: dataset_name (str): dataset name force (bool, optional): if set to True, downloads and overwrites the local test and seed files if false then downloads onlyif the files are not present locally + """ def download_bar(current, total, width=80): - """custom progress bar for downloading data""" + """Custom progress bar for downloading data""" width = shutil.get_terminal_size()[0] // 2 print( f"{current//total*100}% [{'.' * (current//total * int(width))}] [{current}/{total}] bytes", @@ -368,7 +383,7 @@ def download(url: str) -> None: if dataset_name not in EXAMPLE_DATASETS: logger.error( - f"{dataset_name} not in list of available datasets: {str(EXAMPLE_DATASETS)}. Exiting..." + f"{dataset_name} not in list of available datasets: {EXAMPLE_DATASETS!s}. 
Exiting...", ) return seed_url = DATASET_URL.format(dataset=dataset_name, partition="seed") @@ -426,3 +441,38 @@ def safe_serialize_to_string(data: Dict) -> Dict: except Exception: ret[k] = "" return ret + + +def is_s3_uri(uri_string: str) -> bool: + return uri_string is not None and ( + uri_string.startswith("s3://") or uri_string.startswith("s3a://") + ) + + +def extract_bucket_key_from_s3_url(s3_path: str): + # Refer: https://stackoverflow.com/a/48245084 + if not is_s3_uri(s3_path): + logger.warning("URI is not actually an S3 URI: {}", s3_path) + return None + + path_object = urlparse(s3_path) + bucket = path_object.netloc + key = path_object.path + return {"Bucket": bucket, "Key": key.lstrip("/")} + + +def generate_s3_uri_from_bucket_key(bucket: str, key: str) -> str: + return f"s3://{bucket}/{key}" + + +def generate_presigned_url(client, s3_uri, expiration=86400): + s3_params = extract_bucket_key_from_s3_url(s3_uri) + + if not s3_params: + return s3_uri + + return client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": s3_params["Bucket"], "Key": s3_params["Key"]}, + ExpiresIn=expiration, + ) From 594aafac3a35d4b578159964e32f70a312ac2d71 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 2 Dec 2024 14:05:11 -0800 Subject: [PATCH 3/5] Tuple import --- src/autolabel/task_chain/task_chain.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/autolabel/task_chain/task_chain.py b/src/autolabel/task_chain/task_chain.py index ae48b2d3..8961f32d 100644 --- a/src/autolabel/task_chain/task_chain.py +++ b/src/autolabel/task_chain/task_chain.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import pandas as pd from transformers import AutoTokenizer @@ -206,7 +206,9 @@ async def run(self, dataset_df: pd.DataFrame): return dataset def rename_output_columns( - self, dataset: AutolabelDataset, autolabel_config: AutolabelConfig, + self, + dataset: AutolabelDataset, + autolabel_config: AutolabelConfig, ): """ Rename the output columns of the dataset for each intermediate step in the task chain so that From a0b7de6036bfc890a874788ea4a2d20efad2e1e6 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 2 Dec 2024 14:10:39 -0800 Subject: [PATCH 4/5] rm time logs --- src/autolabel/transforms/serp_api.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/autolabel/transforms/serp_api.py b/src/autolabel/transforms/serp_api.py index f627a2c5..bd186576 100644 --- a/src/autolabel/transforms/serp_api.py +++ b/src/autolabel/transforms/serp_api.py @@ -22,7 +22,9 @@ class RefuelSerpAPIWrapper(SerpAPIWrapper): def __init__(self, search_engine=None, params=None, serpapi_api_key=None): super().__init__( - search_engine=search_engine, params=params, serpapi_api_key=serpapi_api_key, + search_engine=search_engine, + params=params, + serpapi_api_key=serpapi_api_key, ) async def arun(self, query: str, **kwargs: Any) -> Dict: @@ -79,7 +81,9 @@ def __init__( self.serp_api_key = serp_api_key self.serp_args = serp_args self.serp_api_wrapper = RefuelSerpAPIWrapper( - search_engine=None, params=self.serp_args, serpapi_api_key=self.serp_api_key, + search_engine=None, + params=self.serp_args, + serpapi_api_key=self.serp_api_key, ) def name(self) -> str: @@ -101,7 +105,6 @@ async def _get_result(self, query): return search_result async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]: - start_time = time.time() for col in 
self.query_columns: if col not in row: logger.warning( @@ -125,10 +128,6 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]: "organic_results", ), } - end_time = time.time() - logger.error( - f"Time taken to run Serp API: {end_time - start_time} seconds", - ) return self._return_output_row(transformed_row) From c7cd145f15cd45295a64646413c5dc9f00912c1f Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 2 Dec 2024 18:23:47 -0800 Subject: [PATCH 5/5] Missing imports --- src/autolabel/task_chain/task_chain.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/autolabel/task_chain/task_chain.py b/src/autolabel/task_chain/task_chain.py index 8961f32d..40355f96 100644 --- a/src/autolabel/task_chain/task_chain.py +++ b/src/autolabel/task_chain/task_chain.py @@ -1,7 +1,9 @@ +import copy import logging from collections import defaultdict from typing import Dict, List, Optional, Tuple +import boto3 import pandas as pd from transformers import AutoTokenizer @@ -17,6 +19,7 @@ from autolabel.few_shot.base_label_selector import BaseLabelSelector from autolabel.labeler import LabelingAgent from autolabel.transforms import TransformFactory +from autolabel.utils import generate_presigned_url, is_s3_uri logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.WARNING)
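
A minimal usage sketch of the S3 helpers added in PATCH 2/5 (is_s3_uri and generate_presigned_url from src/autolabel/utils.py, which TaskChainOrchestrator.safe_convert_uri_to_presigned_url applies to every configured input column). The column name, bucket, and sample rows below are illustrative assumptions rather than values taken from this patch, and boto3 needs AWS credentials configured for presigning to work.

# Sketch, not part of the patch: convert S3/S3A URIs in an input column to
# presigned HTTPS URLs before labeling, the way the task chain now does.
import boto3
import pandas as pd

from autolabel.utils import generate_presigned_url, is_s3_uri

s3_client = boto3.client("s3")

# Hypothetical input data; "image_url" stands in for any configured input column.
df = pd.DataFrame(
    {
        "image_url": [
            "s3://example-bucket/scans/receipt-001.png",  # assumed object
            "https://example.com/already-public.png",  # non-S3 value passes through
        ],
    },
)

# Only S3/S3A URIs are swapped for presigned URLs; generate_presigned_url
# defaults to a 24-hour (86400 second) expiration.
df["image_url"] = df["image_url"].apply(
    lambda uri: generate_presigned_url(s3_client, uri) if is_s3_uri(uri) else uri,
)

The presigned URLs can then be passed to OpenAIVisionLLM as ordinary image URLs, with the original URIs restored afterwards as reset_presigned_url_to_uri does in the task chain.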