diff --git a/src/gretel_client/navigator/__init__.py b/src/gretel_client/navigator/__init__.py index 14698ae3..fcb080a2 100644 --- a/src/gretel_client/navigator/__init__.py +++ b/src/gretel_client/navigator/__init__.py @@ -1,2 +1,6 @@ +from gretel_client.navigator.data_designer.factory import DataDesignerFactory from gretel_client.navigator.data_designer.interface import DataDesigner -from gretel_client.navigator.workflow import NavigatorWorkflow +from gretel_client.navigator.data_designer.sample_to_dataset import ( + DataDesignerFromSampleRecords, +) +from gretel_client.navigator.workflow import DataDesignerWorkflow diff --git a/src/gretel_client/navigator/client/remote.py b/src/gretel_client/navigator/client/remote.py index 01c9e45d..1742dc86 100644 --- a/src/gretel_client/navigator/client/remote.py +++ b/src/gretel_client/navigator/client/remote.py @@ -219,7 +219,6 @@ def submit_batch_workflow( logger.info(f"▢️ Starting your workflow run to generate {num_records} records:") logger.info(f" |-- project_name: {project.name}") logger.info(f" |-- project_id: {project.project_guid}") - logger.info(f" |-- workflow_id: {batch_response.workflow_id}") logger.info(f" |-- workflow_run_id: {batch_response.workflow_run_id}") logger.info(f"πŸ”— -> {workflow_run_url}") diff --git a/src/gretel_client/navigator/data_designer/data_column.py b/src/gretel_client/navigator/data_designer/data_column.py index 8447e901..1dada9f8 100644 --- a/src/gretel_client/navigator/data_designer/data_column.py +++ b/src/gretel_client/navigator/data_designer/data_column.py @@ -117,6 +117,25 @@ def generate_context_column_string(self, exclude: Optional[set[str]] = None) -> + "\n" ) + def get_system_prompt( + self, special_system_instructions: Optional[str] = None + ) -> str: + """Get the system prompt for the column generation task. + + Args: + special_instructions: Special instructions to be added to the system prompt. + + Returns: + System prompt string. 
+ """ + return system_prompt_dict[self.llm_type].format( + special_instructions=( + "" + if special_system_instructions is None + else f"\n{special_system_instructions}\n" + ) + ) + def to_generation_task( self, special_system_instructions: Optional[str] = None, @@ -137,13 +156,7 @@ def to_generation_task( response_column_name=self.name, workflow_label=f"generating {self.name}", llm_type=self.llm_type, - system_prompt=system_prompt_dict[self.llm_type].format( - special_instructions=( - "" - if special_system_instructions is None - else f"\n{special_system_instructions}\n" - ) - ), + system_prompt=self.get_system_prompt(special_system_instructions), client=client, ) diff --git a/src/gretel_client/navigator/data_designer/factory.py b/src/gretel_client/navigator/data_designer/factory.py new file mode 100644 index 00000000..17856eaa --- /dev/null +++ b/src/gretel_client/navigator/data_designer/factory.py @@ -0,0 +1,121 @@ +import logging + +from pathlib import Path +from typing import Optional, Union + +import pandas as pd + +from gretel_client.navigator.data_designer.interface import DataDesigner +from gretel_client.navigator.data_designer.sample_to_dataset import ( + DataDesignerFromSampleRecords, +) +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.types import ( + DEFAULT_MODEL_SUITE, + ModelSuite, + RecordsT, +) + +logger = get_logger(__name__, level=logging.INFO) + + +class DataDesignerFactory: + """Factory class for creating DataDesigner instances. + + Each class method on this object provides a different way to instantiate + a DataDesigner object, depending on your use case and desired workflow. + + Allowed session keyword arguments: + api_key (str): Your Gretel API key. If set to "prompt" and no API key + is found on the system, you will be prompted for the key. + endpoint (str): Specifies the Gretel API endpoint. This must be a fully + qualified URL. The default is "https://api.gretel.cloud". 
+ default_runner (str): Specifies the runner mode. Must be one of "cloud", + "local", "manual", or "hybrid". The default is "cloud". + artifact_endpoint (str): Specifies the endpoint for project and model + artifacts. Defaults to "cloud" for running in Gretel Cloud. If + working in hybrid mode, set to the URL of your artifact storage bucket. + cache (str): Valid options are "yes" or "no". If set to "no", the session + configuration will not be written to disk. If set to "yes", the + session configuration will be written to disk only if one doesn't + already exist. The default is "no". + validate (bool): If `True`, will validate the login credentials at + instantiation. The default is `False`. + clear (bool): If `True`, existing Gretel credentials will be removed. + The default is `False.` + """ + + @classmethod + def from_blank_canvas( + cls, model_suite: ModelSuite = DEFAULT_MODEL_SUITE, **kwargs + ) -> DataDesigner: + """Instantiate an empty DataDesigner instance that can be built up programmatically. + + This initialization method is equivalent to directly instantiating a DataDesigner object. + + Args: + model_suite: The model suite to use for generating synthetic data. Defaults to the + apache-2.0 licensed model suite. + **kwargs: Additional keyword arguments to pass to the DataDesigner constructor. + + Returns: + An instance of DataDesigner with a blank canvas. + """ + logger.info("🎨 Creating DataDesigner instance from blank canvas") + + return DataDesigner(model_suite=model_suite, **kwargs) + + @classmethod + def from_config(cls, config: dict, **kwargs) -> DataDesigner: + """Instantiate a DataDesigner instance from a configuration dictionary. + + This method allows you to specify your data design using a YAML configuration file, + which is then built into a DataDesigner instance the same way you would do so programmatically. + + Args: + config: A YAML configuration file, dict, or string that fully specifies the data design. 
+ **kwargs: Additional keyword arguments to pass to the DataDesigner constructor. + + Returns: + An instance of DataDesigner configured with the data seeds and generated data columns + defined in the configuration dictionary. + """ + logger.info("🎨 Creating DataDesigner instance from config") + + return DataDesigner.from_config(config, **kwargs) + + @classmethod + def from_sample_records( + cls, + sample_records: Union[str, Path, pd.DataFrame, RecordsT], + *, + subsample_size: Optional[int] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + **kwargs, + ) -> DataDesigner: + """Instantiate a DataDesigner instance from sample records. + + Use this subclass of DataDesigner when you want to turn a few sample records + into a rich, diverse synthetic dataset (Sample-to-Dataset). + + Args: + sample_records: Sample records from which categorical data seeds will be extracted + and optionally used to create generated data columns. + subsample_size: The number of records to use from the sample records. If None, + all records will be used. If the subsample size is larger than the sample records, + the full sample will be used. + model_suite: The model suite to use for generating synthetic data. Defaults to the + apache-2.0 licensed model suite. + + Returns: + An instance of DataDesigner configured to extract data seeds from the sample records + and optionally create generated data columns for each field in the sample records. 
+ """ + logger.info("🎨 Creating DataDesigner instance from sample records") + + return DataDesignerFromSampleRecords( + sample_records=sample_records, + subsample_size=subsample_size, + model_suite=model_suite, + **kwargs, + ) diff --git a/src/gretel_client/navigator/data_designer/interface.py b/src/gretel_client/navigator/data_designer/interface.py index 74138b57..6c303c6d 100644 --- a/src/gretel_client/navigator/data_designer/interface.py +++ b/src/gretel_client/navigator/data_designer/interface.py @@ -23,6 +23,7 @@ ) from gretel_client.navigator.log import get_logger from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.constants import PREVIEW_NUM_RECORDS from gretel_client.navigator.tasks.evaluate_dataset import EvaluateDataset from gretel_client.navigator.tasks.generate.generate_seed_category_values import ( GenerateSeedCategoryValues, @@ -51,10 +52,11 @@ ValidateCode, ) from gretel_client.navigator.workflow import ( + DataDesignerBatchJob, + DataDesignerWorkflow, DataSpec, - NavigatorBatchJob, - NavigatorWorkflow, PreviewResults, + Step, ) logger = get_logger(__name__, level=logging.INFO) @@ -69,17 +71,15 @@ class DataDesigner: to assemble a scalable synthetic data generation workflow. Args: - dataset_description: Optional description of the dataset to be generated. This description will - be used in prompts to provide high-level context about the dataset. - special_system_instructions: Optional instructions for the system to follow when generating - the dataset. These instructions will be added to the system prompts. model_suite: The model suite to use for generating synthetic data. Defaults to the - Apache-2.0 licensed model suite. + apache-2.0 licensed model suite. session: Optional Gretel session configuration object. If not provided, the session will be configured based on the provided session_kwargs or cached session configuration. 
+ special_system_instructions: Optional instructions for the system to follow when generating + the dataset. These instructions will be added to the system prompts. **session_kwargs: kwargs for your Gretel session. See options below. - Keyword Args: + Allowed session keyword arguments: api_key (str): Your Gretel API key. If set to "prompt" and no API key is found on the system, you will be prompted for the key. endpoint (str): Specifies the Gretel API endpoint. This must be a fully @@ -102,10 +102,9 @@ class DataDesigner: def __init__( self, *, - dataset_description: Optional[str] = None, - special_system_instructions: Optional[str] = None, model_suite: ModelSuite = DEFAULT_MODEL_SUITE, session: Optional[ClientConfig] = None, + special_system_instructions: Optional[str] = None, **session_kwargs, ): @@ -113,22 +112,24 @@ def __init__( configure_session(**session_kwargs) session = get_session_config() + logger.info(f"🦜 Using {model_suite} model suite") self._session = session self._client = get_navigator_client(session=session, **session_kwargs) self._seed_categories: dict[str, list[SeedCategory]] = {} - self._data_columns: dict[str, list[GeneratedDataColumn]] = {} self._seed_subcategory_names = defaultdict(list) + self._generated_data_columns: dict[str, list[GeneratedDataColumn]] = {} self._validators: dict[str, list[Task]] = {} self._evaluators: dict[str, Task] = {} self._eval_type: Optional[str] = None + + self.special_system_instructions = special_system_instructions + datetime_label = datetime.now().isoformat(timespec="seconds") self._workflow_kwargs = { "model_suite": check_model_suite(model_suite), "session": self._session, "client": self._client, - "workflow_name": f"NavDD-{datetime.now().isoformat(timespec='seconds')}", + "workflow_name": f"{self.__class__.__name__}-{datetime_label}", } - self.dataset_description = dataset_description - self.special_system_instructions = special_system_instructions @property def _seed_category_names(self) -> list[str]: @@ 
-136,79 +137,38 @@ def _seed_category_names(self) -> list[str]: return list(self._seed_categories.keys()) @property - def data_column_names(self) -> list[str]: - """Return a list of the names of the data columns (note order matters).""" - return list(self._data_columns.keys()) - - @property - def seed_column_names(self) -> list[str]: + def categorical_seed_column_names(self) -> list[str]: """Return a list of the names of the seed columns, including subcategories.""" return self._seed_category_names + [ s for ss in self._seed_subcategory_names.values() for s in ss ] + @property + def generated_data_column_names(self) -> list[str]: + """Return a list of the names of the data columns (note order matters).""" + return list(self._generated_data_columns.keys()) + @property def all_column_names(self) -> list[str]: """Return a list of all seed (including seed subcategories) and data column names.""" - return self.seed_column_names + self.data_column_names + return self.categorical_seed_column_names + self.generated_data_column_names @property def categorical_seed_columns(self) -> CategoricalDataSeeds: """Return a CategoricalDataSeeds instance that contains the seed categories.""" - if len(self._seed_categories) == 0: - raise ValueError("No seed categories have been defined.") - - return CategoricalDataSeeds( - seed_categories=list(self._seed_categories.values()) - ) - - @property - def data_spec(self) -> DataSpec: - """Return a DataSpec instance that defines the schema and other relevant info.""" - code_lang = None - code_columns = [] - validation_columns = [] - llm_judge_column_name = None - for validator in self._validators.values(): - if validator.name == "validate_code": - column_suffix = ( - VALIDATE_SQL_COLUMN_SUFFIXES - if validator.config.code_lang in SQL_DIALECTS - else VALIDATE_PYTHON_COLUMN_SUFFIXES - ) - code_lang = validator.config.code_lang - code_columns.extend(validator.config.code_columns) - for col in code_columns: - for prefix in column_suffix: - 
validation_columns.append(f"{col}{prefix}") - break - if self._eval_type in list(LLMJudgePromptTemplateType): - llm_judge_column_name = f"{self._eval_type}_llm_judge_results" - return DataSpec( - seed_category_names=self._seed_category_names, - seed_subcategory_names=dict(self._seed_subcategory_names), - data_column_names=self.data_column_names, - validation_column_names=validation_columns, - code_column_names=code_columns, - code_lang=code_lang, - eval_type=self._eval_type, - llm_judge_column_name=llm_judge_column_name, - ) + if len(self._seed_categories) > 0: + return CategoricalDataSeeds( + seed_categories=list(self._seed_categories.values()) + ) + return CategoricalDataSeeds(seed_categories=[]) @classmethod - def from_config( - cls, - config: Union[dict, str, Path], - session: Optional[ClientConfig] = None, - **session_kwargs, - ) -> Self: + def from_config(cls, config: Union[dict, str, Path], **kwargs) -> Self: """Instantiate a DataDesigner instance from a YAML configuration str, dict, or file. Args: config: A YAML configuration file, dict, or string. - session: Optional Gretel session configuration object. If not provided, the session will be - configured based on the provided session_kwargs or cached session configuration. - **session_kwargs: kwargs for your Gretel session. + **kwargs: Additional keyword arguments to pass to the DataDesigner constructor. Returns: An instance of DataDesigner configured with the settings from the provided YAML config. 
@@ -220,12 +180,10 @@ def from_config( config = requests.get(config).content.decode("utf-8") config = smart_load_yaml(config) - designer = cls( - dataset_description=config.get("dataset_description"), + dd = cls( special_system_instructions=config.get("special_system_instructions"), model_suite=config.get("model_suite", DEFAULT_MODEL_SUITE), - session=session, - **session_kwargs, + **kwargs, ) if "categorical_seed_columns" not in config: @@ -235,14 +193,14 @@ def from_config( ) for seed_category in config.get("categorical_seed_columns", []): - designer.add_categorical_seed_column(**seed_category) + dd.add_categorical_seed_column(**seed_category) logger.debug(f"🌱 Adding seed category: {seed_category['name']}") for data_column in config.get("generated_data_columns", []): - designer.add_generated_data_column(**data_column) + dd.add_generated_data_column(**data_column) logger.debug(f"πŸ’½ Adding data column: {data_column['name']}") - if len(designer.all_column_names) == 0: + if len(dd.all_column_names) == 0: raise ValueError("No seed or data columns were defined in the config.") # Post processors are applied after the data generation process @@ -253,7 +211,7 @@ def from_config( eval_type = None for processor in post_processors: if "validator" in processor: - designer.add_data_validator( + dd.add_validator( validator=ValidatorType(processor["validator"]), **processor["settings"], ) @@ -271,9 +229,7 @@ def from_config( eval_type = LLMJudgePromptTemplateType( processor["evaluator"] ).value - designer.add_dataset_evaluation( - eval_type=processor["evaluator"], **settings - ) + dd.add_evaluator(eval_type=processor["evaluator"], **settings) if code_lang and eval_type: if (code_lang in SQL_DIALECTS and eval_type != "text_to_sql") or ( code_lang == "python" and eval_type != "text_to_python" @@ -283,35 +239,46 @@ def from_config( f"the `{eval_type}` evaluator. Please ensure the code language " "of the validator and evaluator are compatible." 
) - designer._config = config - return designer + dd._config = config - def _create_sequential_task_list( - self, data_seeds: Optional[CategoricalDataSeeds] = None - ) -> list[Task]: - """Returns a list of tasks to be executed sequentially in the workflow. + return dd + + def _create_workflow_steps( + self, + num_records: Optional[int] = None, + data_seeds: Optional[CategoricalDataSeeds] = None, + dataset_context: Optional[str] = None, + verbose_logging: bool = False, + **kwargs, + ) -> tuple[list[Step], Optional[CategoricalDataSeeds]]: + """Create workflow steps from a list of tasks. Args: - data_seeds: Data seeds to use in place of what was defined in the - configuration. This is useful if you have pre-generated data - seeds or want to experiment with different seed categories/values. + num_records: Number of records to be generated. + data_seeds: Data seeds to use in place of what was defined in the configuration. + This is useful if you have pre-generated data seeds or want to experiment with + different seed categories/values. + dataset_context: Context for the dataset to be used in the seed value generation task. + verbose_logging: If True, additional logging will be displayed. Returns: - A list of Task objects to be executed sequentially in the workflow. + A tuple that contains a list of workflow steps and the data seeds used in the workflow. """ if len(self._seed_categories) == 0: raise ValueError("No seed columns have been defined.") task_list = [] + num_records = num_records or PREVIEW_NUM_RECORDS data_seeds = data_seeds or self.categorical_seed_columns + if data_seeds.needs_generation: # If any seed category / subcategory values need generation, # we start with the seed value generation task. 
task_list.append( GenerateSeedCategoryValues( seed_categories=list(self._seed_categories.values()), - dataset_context=self.dataset_description, + dataset_context=dataset_context, client=self._client, ) ) @@ -327,10 +294,10 @@ def _create_sequential_task_list( # Given fully-specified data seeds, we next add a task to # sample them in to a seed dataset. - task_list.append(SampleDataSeeds(client=self._client)) + task_list.append(SampleDataSeeds(num_records=num_records, client=self._client)) # Iterate over the data columns and create generation tasks for each. - for column in self._data_columns.values(): + for column in self._generated_data_columns.values(): task = column.to_generation_task( self.special_system_instructions, client=self._client ) @@ -344,30 +311,30 @@ def _create_sequential_task_list( for eval_task in self._evaluators.values(): task_list.append(eval_task) - return task_list + steps = DataDesignerWorkflow.create_steps_from_sequential_tasks( + task_list, verbose_logging=verbose_logging + ) + return steps, data_seeds - def _create_workflow_steps( - self, - task_list: Optional[list[Task]] = None, - data_seeds: Optional[CategoricalDataSeeds] = None, - verbose_logging: bool = False, - ) -> list[dict]: - """Create workflow steps from a list of tasks. + def _get_data_seeds_task( + self, dataset_context: Optional[str] = None, **kwargs + ) -> Task: + """Get the task that generates seed category values.""" - Args: - task_list: List of tasks to be executed sequentially in the workflow. - data_seeds: Data seeds to use in place of what was defined in the configuration. - This is useful if you have pre-generated data seeds or want to experiment with - different seed categories/values. - verbose_logging: If True, additional logging will be displayed. + if len(self._seed_categories) == 0: + raise ValueError("No seed categories have been defined.") - Returns: - A list of workflow steps that can be executed by the NavigatorWorkflow. 
- """ - if task_list is None: - task_list = self._create_sequential_task_list(data_seeds) - return NavigatorWorkflow.create_steps_from_sequential_tasks( - task_list, verbose_logging=verbose_logging + if not self.categorical_seed_columns.needs_generation: + logger.warning("⚠️ Your categorical data seeds do not require generation.") + return LoadDataSeeds( + categorical_data_seeds=self.categorical_seed_columns, + client=self._client, + ) + + return GenerateSeedCategoryValues( + seed_categories=list(self._seed_categories.values()), + dataset_context=dataset_context, + client=self._client, ) def _validate_generated_data_column_inputs( @@ -388,7 +355,7 @@ def _validate_generated_data_column_inputs( A tuple containing the validated inputs for the data column. """ - if name in self._data_columns: + if name in self._generated_data_columns: raise ValueError(f"Column name `{name}` already exists.") if name in self._seed_categories: raise ValueError(f"Column name `{name}` already exists as a seed category.") @@ -401,22 +368,22 @@ def _validate_generated_data_column_inputs( f"The `generation_prompt` field of `{name}` contains template keywords that " "are not available as columns.\n" f"* Template keywords found in `generation_prompt`: {template_kwargs}\n" - f"* Available seed columns: {self.seed_column_names}\n" - f"* Available data columns: {self.data_column_names}" + f"* Available seed columns: {self.categorical_seed_column_names}\n" + f"* Available data columns: {self.generated_data_column_names}" ) if isinstance(columns_to_list_in_prompt, str): if columns_to_list_in_prompt == "all": columns_to_list_in_prompt = self.all_column_names elif columns_to_list_in_prompt == "all_categorical_seed_columns": - columns_to_list_in_prompt = self.seed_column_names + columns_to_list_in_prompt = self.categorical_seed_column_names elif columns_to_list_in_prompt == "all_generated_data_columns": - if len(self.data_column_names) == 0: + if len(self.generated_data_column_names) == 0: 
logger.warning( f"⚠️ The generated data column `{name}` has set `columns_to_list_in_prompt` " "to 'all_generated_data_columns', but no data columns have been defined." ) - columns_to_list_in_prompt = self.data_column_names + columns_to_list_in_prompt = self.generated_data_column_names else: raise ValueError( f"If not None, `columns_to_list_in_prompt` must be a list of column names or " @@ -434,44 +401,46 @@ def _validate_generated_data_column_inputs( f"The `columns_to_list_in_prompt` field of `{name}` contains invalid columns. " "Only seed or data columns that have been defined before the current " "column can be added as context.\n" - f"* Available seed columns: {self.seed_column_names}\n" - f"* Available data columns: {self.data_column_names}\n" + f"* Available seed columns: {self.categorical_seed_column_names}\n" + f"* Available data columns: {self.generated_data_column_names}\n" ) return name, generation_prompt, columns_to_list_in_prompt - def generate_seed_category_values( - self, verbose_logging: bool = False - ) -> CategoricalDataSeeds: - """Generate values for seed categories that require generation. + def add_generated_data_column( + self, + name: str, + *, + generation_prompt: str, + columns_to_list_in_prompt: Optional[list[str]] = None, + llm_type: LLMType = LLMType.NATURAL_LANGUAGE, + output_type: OutputColumnType = OutputColumnType.AUTO, + ) -> None: + """Add a generated data column to the data design. + + Generated data columns are fully generated by an LLM using the provided generation prompt. Args: - verbose_logging: If True, additional logging will be displayed during execution. + name: The name of the data column. + generation_prompt: The prompt that will be used to generate the data column. The prompt and can + contain template keywords that reference seed columns or other existing data columns. + columns_to_list_in_prompt: List of seed and/or data columns to list as context in the generation prompt. 
+ llm_type: LLM type for generation of the column. Must be one of ["nl", "code", "judge"]. + output_type: Output type for the column. Must be one of ["auto", "text", "dict", "list", "code"]. + If "auto", the output type will be "code" when llm_type is "code" and "text" otherwise. """ - if len(self._seed_categories) == 0: - raise ValueError("No seed categories have been defined.") - - if not self.categorical_seed_columns.needs_generation: - logger.warning("⚠️ Your categorical data seeds do not require generation.") - return self.categorical_seed_columns - - workflow = NavigatorWorkflow(**self._workflow_kwargs) - task = GenerateSeedCategoryValues( - seed_categories=list(self._seed_categories.values()), - dataset_context=self.dataset_description, - client=self._client, - ) - workflow.add_steps(workflow.create_steps_from_sequential_tasks([task])) - seeds = workflow._generate_preview(verbose=verbose_logging).output - if seeds is None: - raise ValueError( - "An error occurred while generating your categorical seed values. " - "Please check that your configuration is as expected, restart your session, and try again. " - "If the problem persists, please contact support and/or submit a GitHub issue to the gretel-client repo." 
+ name, generation_prompt, columns_to_list_in_prompt = ( + self._validate_generated_data_column_inputs( + name, generation_prompt, columns_to_list_in_prompt ) - - seeds["seed_categories"] = seeds["seed_categories"][::-1] - return CategoricalDataSeeds(**seeds) + ) + self._generated_data_columns[name] = GeneratedDataColumn( + name=name, + generation_prompt=generation_prompt, + columns_to_list_in_prompt=columns_to_list_in_prompt, + output_type=output_type, + llm_type=llm_type, + ) def add_categorical_seed_column( self, @@ -482,10 +451,13 @@ def add_categorical_seed_column( weights: Optional[list[float]] = None, num_new_values_to_generate: Optional[int] = None, subcategories: Optional[Union[list[SeedSubcategory], list[dict]]] = None, + **kwargs, ) -> None: """Add a seed category to the data design. - All seed categories must be added *before* any generated data columns are added. + All seed categories must be added *before* any generated data columns are added. This is to enable + validation of any keyword arguments in the generation prompts, which can only reference existing + seed and data columns. A seed category is a categorical column with values that can be user-provided or generated by an LLM using the GenerateSeedCategoryValues Task. The purpose of categorical data @@ -506,11 +478,6 @@ def add_categorical_seed_column( SeedSubcategory instance or a dictionary with the same fields as a parent seed category *except* for `weights` and `subcategories`. """ - if len(self._data_columns) > 0: - raise ValueError( - "Seed categories must be added *before* all data columns.\n" - f"* Current data columns: {self.data_column_names}" - ) if num_new_values_to_generate is None and values is None: raise ValueError( "You must provide *at least* one of `values` or `num_new_values_to_generate`." 
@@ -539,44 +506,10 @@ def add_categorical_seed_column( weights=weights or [], num_new_values_to_generate=num_new_values_to_generate, subcategories=subcategories or [], + **kwargs, ) - def add_generated_data_column( - self, - name: str, - *, - generation_prompt: str, - columns_to_list_in_prompt: Optional[list[str]] = None, - llm_type: LLMType = LLMType.NATURAL_LANGUAGE, - output_type: OutputColumnType = OutputColumnType.AUTO, - ) -> None: - """Add a generated data column to the data design. - - Generated data columns are fully generated by an LLM using the provided generation prompt. - - Args: - name: The name of the data column. - generation_prompt: The prompt that will be used to generate the data column. The prompt and can - contain template keywords that reference seed columns or other existing data columns. - columns_to_list_in_prompt: List of seed and/or data columns to list as context in the generation prompt. - llm_type: LLM type for generation of the column. Must be one of ["nl", "code", "judge"]. - output_type: Output type for the column. Must be one of ["auto", "text", "dict", "list", "code"]. - If "auto", the output type will be "code" when llm_type is "code" and "text" otherwise. - """ - name, generation_prompt, columns_to_list_in_prompt = ( - self._validate_generated_data_column_inputs( - name, generation_prompt, columns_to_list_in_prompt - ) - ) - self._data_columns[name] = GeneratedDataColumn( - name=name, - generation_prompt=generation_prompt, - columns_to_list_in_prompt=columns_to_list_in_prompt, - output_type=output_type, - llm_type=llm_type, - ) - - def add_data_validator(self, validator: ValidatorType, **settings) -> None: + def add_validator(self, validator: ValidatorType, **settings) -> None: """Add a data validator to the data design. 
Args: @@ -606,7 +539,7 @@ def add_data_validator(self, validator: ValidatorType, **settings) -> None: else: raise ValueError(f"Unknown validator type: {validator}") - def add_dataset_evaluation( + def add_evaluator( self, eval_type: Union[EvaluationType, LLMJudgePromptTemplateType], **settings ) -> None: """Add a dataset evaluation task to the data design. @@ -645,7 +578,7 @@ def add_dataset_evaluation( raise ValueError(f"Unknown evaluation type: {eval_type}") self._evaluators["evaluate_dataset"] = EvaluateDataset( - seed_columns=self.seed_column_names, + seed_columns=self.categorical_seed_column_names, ordered_list_like_columns=settings.get("ordered_list_like_columns", []), other_list_like_columns=settings.get("list_like_columns", []), llm_judge_column=settings.get("llm_judge_column", llm_judge_column), @@ -654,10 +587,68 @@ def add_dataset_evaluation( ) self._eval_type = eval_type + def get_generated_data_column(self, name: str) -> GeneratedDataColumn: + """Get a data column by name.""" + return self._generated_data_columns[name] + + def get_data_spec( + self, data_seeds: Optional[CategoricalDataSeeds] = None + ) -> DataSpec: + """Return a DataSpec instance that defines the schema and other relevant info.""" + code_lang = None + code_columns = [] + validation_columns = [] + llm_judge_column_name = None + for validator in self._validators.values(): + if validator.name == "validate_code": + column_suffix = ( + VALIDATE_SQL_COLUMN_SUFFIXES + if validator.config.code_lang in SQL_DIALECTS + else VALIDATE_PYTHON_COLUMN_SUFFIXES + ) + code_lang = validator.config.code_lang + code_columns.extend(validator.config.code_columns) + for col in code_columns: + for prefix in column_suffix: + validation_columns.append(f"{col}{prefix}") + break + if self._eval_type in list(LLMJudgePromptTemplateType): + llm_judge_column_name = f"{self._eval_type}_llm_judge_results" + + seed_category_names = ( + self._seed_category_names + if data_seeds is None + else [s.name for s in 
data_seeds.seed_categories] + ) + + seed_subcategory_names = ( + dict(self._seed_subcategory_names) + if data_seeds is None + else { + s.name: [ss.name for ss in s.subcategories] + for s in data_seeds.seed_categories + } + ) + + return DataSpec( + seed_category_names=seed_category_names, + seed_subcategory_names=seed_subcategory_names, + data_column_names=self.generated_data_column_names, + validation_column_names=validation_columns, + code_column_names=code_columns, + code_lang=code_lang, + eval_type=self._eval_type, + llm_judge_column_name=llm_judge_column_name, + ) + def export_as_workflow_config( self, path: Optional[Union[str, Path]] = None, + *, + num_records: Optional[int] = None, + dataset_context: Optional[str] = None, data_seeds: Optional[CategoricalDataSeeds] = None, + **kwargs, ) -> Union[dict, None]: """Export the data design as a Navigator workflow configuration. @@ -666,6 +657,8 @@ def export_as_workflow_config( Args: path: Optional JSON or YAML path to save the workflow configuration to. + num_records: Number of records to be generated. + dataset_context: Context for the dataset to be used in the seed generation task. data_seeds: Data seeds to use in place of what was defined in the DataDesigner configuration. This is useful if, for example, you generated category values using `generate_seed_category_values` and want to @@ -674,9 +667,18 @@ def export_as_workflow_config( Returns: If no path is provided, the configuration will be returned as a dict. 
""" - workflow = NavigatorWorkflow(**self._workflow_kwargs) - workflow.add_steps(self._create_workflow_steps(data_seeds=data_seeds)) + workflow = DataDesignerWorkflow(**self._workflow_kwargs) + + steps, _ = self._create_workflow_steps( + num_records=num_records, + dataset_context=dataset_context, + data_seeds=data_seeds, + **kwargs, + ) + workflow.add_steps(steps) config = workflow.to_dict() + config["globals"]["num_records"] = num_records + if path is None: return config else: @@ -693,19 +695,60 @@ def export_as_workflow_config( f"You provided: {path.suffix}" ) - def get_seed_category(self, name: str) -> SeedCategory: - """Get a seed category by name.""" - return self._seed_categories[name] + def ingest_categorical_data_seeds( + self, data_seeds: Union[dict, CategoricalDataSeeds] + ) -> None: + """Ingest pre-generated data seeds into the data design. - def get_data_column(self, name: str) -> GeneratedDataColumn: - """Get a data column by name.""" - return self._data_columns[name] + Args: + data_seeds: Ingest entire CategoricalDataSeeds instance into the data design. + """ + if isinstance(data_seeds, dict): + data_seeds = CategoricalDataSeeds(**data_seeds) + logger.info("🌱 Ingesting categorical data seeds into data design") + for seed in data_seeds.seed_categories: + try: + self.add_categorical_seed_column(**seed.model_dump()) + except NotImplementedError: + raise NotImplementedError( + "Ingesting data seeds is not supported for the " + f"{self.__class__.__name__} subclass." + ) + + def run_data_seeds_step( + self, + *, + dataset_context: Optional[str] = None, + verbose_logging: bool = False, + **kwargs, + ) -> CategoricalDataSeeds: + """Run workflow step that generates / extracts / defines data seeds. + + Args: + dataset_context: Context for the dataset to be used in the seed value generation task. + verbose_logging: If True, additional logging will be displayed during execution. 
+ """ + task = self._get_data_seeds_task(dataset_context=dataset_context, **kwargs) + workflow = DataDesignerWorkflow(**self._workflow_kwargs) + workflow.add_steps(workflow.create_steps_from_sequential_tasks([task])) + seeds = workflow._generate_preview(verbose=verbose_logging).output + + if seeds is None or "seed_categories" not in seeds: + raise ValueError( + "An error occurred while generating your categorical seed values. " + "Please check that your configuration is as expected, restart your " + "session, and try again. If the problem persists, please contact support " + "and/or submit a GitHub issue to the gretel-client repo." + ) + + return CategoricalDataSeeds(**seeds) def generate_dataset_preview( self, *, data_seeds: Optional[CategoricalDataSeeds] = None, verbose_logging: bool = False, + **kwargs, ) -> PreviewResults: """Generate a preview synthetic dataset using the current workflow steps. @@ -718,10 +761,37 @@ def generate_dataset_preview( Returns: A PreviewResults instance containing the outputs of the workflow. """ - workflow = NavigatorWorkflow(**self._workflow_kwargs) - workflow.add_steps(self._create_workflow_steps(data_seeds=data_seeds)) + workflow = DataDesignerWorkflow(**self._workflow_kwargs) + + steps, data_seeds = self._create_workflow_steps( + data_seeds=data_seeds, + verbose_logging=False, + **kwargs, + ) + + workflow.add_steps(steps) + preview = workflow.generate_dataset_preview(verbose_logging=verbose_logging) - preview.data_spec = self.data_spec + + # In order to have a data spec, we need to know the final data seeds. We grab them + # from the Workflow outputs to be certain they are the same as the ones used in the preview. + seed_steps = [ + s.name + for s in workflow._steps + if workflow.step_io_map[s.name]["output"] == "categorical_data_seeds" + ] + if len(seed_steps) == 0: + raise ValueError( + "The workflow does not contain a step that outputs categorical data seeds. 
" + "This is required to generate a preview dataset by upsampling sample records." + ) + seed_step = seed_steps[-1] + data_seeds = CategoricalDataSeeds(**preview.outputs_by_step[seed_step]) + + # The data spec has info about all the various columns in the dataset. + # We use this to display the preview data in a more human-readable format. + preview.data_spec = self.get_data_spec(data_seeds) + if preview.evaluation_results is not None and self._eval_type in list( LLMJudgePromptTemplateType ): @@ -736,7 +806,8 @@ def submit_batch_workflow( num_records: int, project_name: Optional[str] = None, data_seeds: Optional[CategoricalDataSeeds] = None, - ) -> NavigatorBatchJob: + **kwargs, + ) -> DataDesignerBatchJob: """Submit a batch job to generate a synthetic dataset. Args: @@ -749,21 +820,21 @@ def submit_batch_workflow( configuration will be used, including generating new values if needed. Returns: - NavigatorBatchJob instance containing the workflow run details and helper + NavigatorWorkflowBatchJob instance containing the workflow run details and helper methods for fetching the results. 
""" - workflow = NavigatorWorkflow(**self._workflow_kwargs) - workflow.add_steps( - self._create_workflow_steps(data_seeds=data_seeds, verbose_logging=True) + workflow = DataDesignerWorkflow(**self._workflow_kwargs) + steps, _ = self._create_workflow_steps( + num_records=num_records, + data_seeds=data_seeds, + verbose_logging=True, + **kwargs, ) - workflow_run = workflow.submit_batch_job( + workflow.add_steps(steps) + batch_job = workflow.submit_batch_job( num_records=num_records, project_name=project_name ) - return NavigatorBatchJob( - workflow_step_names=workflow.workflow_step_names, - workflow_run=workflow_run, - data_spec=self.data_spec, - ) + return batch_job def __repr__(self): seed_categories = [ @@ -775,20 +846,34 @@ def __repr__(self): for name, s in self._seed_categories.items() ] + categorical_seed_columns = ( + f" categorical_seed_columns: {seed_categories}\n" + if len(seed_categories) > 0 + else "" + ) + + generated_data_columns = ( + f" generated_data_columns: {self.generated_data_column_names}\n" + if len(self.generated_data_column_names) > 0 + else "" + ) + validators = ( f" validator: {[f'{k}:{v.config.code_lang}' for k,v in self._validators.items()][0]}\n" if len(self._validators) > 0 else "" ) - evaluation = ( + evaluator = ( f" evaluator: {self._eval_type}\n" if self._eval_type is not None else "" ) - return ( - f"{self.__class__.__name__}(\n" - f" categorical_seed_columns: {seed_categories}\n" - f" generated_data_columns: {self.data_column_names}\n" + + column_repr = ( + f"{categorical_seed_columns}" + f"{generated_data_columns}" f"{validators}" - f"{evaluation}" - ")" + f"{evaluator}" ) + newline = "\n" if len(column_repr) > 0 else "" + + return f"{self.__class__.__name__}({newline}{column_repr})" diff --git a/src/gretel_client/navigator/data_designer/sample_to_dataset.py b/src/gretel_client/navigator/data_designer/sample_to_dataset.py new file mode 100644 index 00000000..8d1cc68c --- /dev/null +++ 
b/src/gretel_client/navigator/data_designer/sample_to_dataset.py @@ -0,0 +1,353 @@ +import logging + +from typing import Optional, Union + +import pandas as pd + +from gretel_client.config import ClientConfig +from gretel_client.navigator.data_designer.interface import DataDesigner +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.constants import S2D_PREVIEW_NUM_RECORDS +from gretel_client.navigator.tasks.extract_data_seeds_from_sample_records import ( + ExtractDataSeedsFromSampleRecords, +) +from gretel_client.navigator.tasks.generate_dataset_from_sample_records import ( + GenerateDatasetFromSampleRecords, +) +from gretel_client.navigator.tasks.load_data_seeds import LoadDataSeeds +from gretel_client.navigator.tasks.seed.sample_data_seeds import SampleDataSeeds +from gretel_client.navigator.tasks.types import ( + CategoricalDataSeeds, + DEFAULT_MODEL_SUITE, + LLMJudgePromptTemplateType, + ModelSuite, + RecordsT, + SeedSubcategory, + SQL_DIALECTS, +) +from gretel_client.navigator.tasks.utils import process_sample_records +from gretel_client.navigator.tasks.validate.validate_code import ( + VALIDATE_PYTHON_COLUMN_SUFFIXES, + VALIDATE_SQL_COLUMN_SUFFIXES, +) +from gretel_client.navigator.workflow import DataDesignerWorkflow, DataSpec, Step + +logger = get_logger(__name__, level=logging.INFO) + + +class DataDesignerFromSampleRecords(DataDesigner): + """DataDesigner subclass that is initialized from a sample of records. + + Use this subclass of DataDesigner when you want to turn a few sample records + into a rich, diverse synthetic dataset (Sample-to-Dataset). + + Args: + sample_records: Sample records from which categorical data seeds will be extracted + and optionally used to create generated data columns. + subsample_size: The number of records to use from the sample records. If None, + all records will be used. 
If the subsample size is larger than the sample records, + the full sample will be used. + model_suite: The model suite to use for generating synthetic data. Defaults to the + apache-2.0 licensed model suite. + session: Optional Gretel session configuration object. If not provided, the session will be + configured based on the provided session_kwargs or cached session configuration. + special_system_instructions: Optional instructions for the system to follow when generating + the dataset. These instructions will be added to the system prompts. + **session_kwargs: kwargs for your Gretel session. See options below. + + Keyword Args: + api_key (str): Your Gretel API key. If set to "prompt" and no API key + is found on the system, you will be prompted for the key. + endpoint (str): Specifies the Gretel API endpoint. This must be a fully + qualified URL. The default is "https://api.gretel.cloud". + default_runner (str): Specifies the runner mode. Must be one of "cloud", + "local", "manual", or "hybrid". The default is "cloud". + artifact_endpoint (str): Specifies the endpoint for project and model + artifacts. Defaults to "cloud" for running in Gretel Cloud. If + working in hybrid mode, set to the URL of your artifact storage bucket. + cache (str): Valid options are "yes" or "no". If set to "no", the session + configuration will not be written to disk. If set to "yes", the + session configuration will be written to disk only if one doesn't + already exist. The default is "no". + validate (bool): If `True`, will validate the login credentials at + instantiation. The default is `False`. + clear (bool): If `True`, existing Gretel credentials will be removed. 
+ The default is `False`. + """ + + def __init__( + self, + sample_records: RecordsT, + *, + subsample_size: Optional[int] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + session: Optional[ClientConfig] = None, + **session_kwargs, + ): + super().__init__( + model_suite=model_suite, + session=session, + **session_kwargs, + ) + + processed_sample_records = process_sample_records( + sample_records, + subsample_size=subsample_size, + ) + + self._sample_records: Optional[RecordsT] = processed_sample_records + + for name in list(pd.DataFrame.from_records(self._sample_records)): + self.add_generated_data_column( + name, generation_prompt="(from_sample_records)" + ) + + def _create_workflow_steps( + self, + num_records: Optional[int] = None, + data_seeds: Optional[CategoricalDataSeeds] = None, + dataset_context: Optional[str] = None, + verbose_logging: bool = False, + **kwargs, + ) -> tuple[list[Step], Optional[CategoricalDataSeeds]]: + """Create workflow steps from a list of tasks. + + Args: + num_records: The number of records to generate. + data_seeds: Data seeds to use in place of what was defined in the configuration. + This is useful if you have pre-generated data seeds or want to experiment with + different seed categories/values. + dataset_context: Context for the dataset to be used in the seed generation task. + verbose_logging: If True, additional logging will be displayed. + + Returns: + A tuple that contains a list of workflow steps and the data seeds used in the workflow. + """ + if len(self._seed_categories) > 0 and data_seeds is None: + data_seeds = self.categorical_seed_columns + if data_seeds.needs_generation: + raise ValueError( + "Category value generation is not supported for seeds that were " + "extracted from a data sample. Please provide pre-generated seed values " + "or switch to initializing DataDesigner directly or from a config." 
+ ) + + task_list = [] + num_records = num_records or S2D_PREVIEW_NUM_RECORDS + + # We allow arguments from the extraction and data generation tasks to be passed as kwargs. + # Pop extraction task args so that only data generation args remain. + extract_data_seeds_kwargs = { + k: kwargs.pop(k) + for k in list(kwargs.keys()) + if k in ["max_num_seeds", "num_assistants", "system_prompt_type"] + } + + if data_seeds is not None: + task_list.append( + LoadDataSeeds(categorical_data_seeds=data_seeds, client=self._client) + ) + + else: + task_list.append( + ExtractDataSeedsFromSampleRecords( + sample_records=self._sample_records, + client=self._client, + dataset_context=dataset_context, + **extract_data_seeds_kwargs, + ) + ) + + # If any generated data columns are based on sample records, + # we are in the full sample-to-dataset use case + if any( + column.generation_prompt == "(from_sample_records)" + for column in self._generated_data_columns.values() + ): + task_list.append( + GenerateDatasetFromSampleRecords( + sample_records=self._sample_records, + client=self._client, + dataset_context=dataset_context, + target_num_records=num_records, + **kwargs, + ) + ) + + # Otherwise, we need to sample the data seeds to generate a dataset. + else: + task_list.append( + SampleDataSeeds(num_records=num_records, client=self._client) + ) + + # Append all other generated data columns that have been added. 
+ for column in self._generated_data_columns.values(): + if column.generation_prompt == "(from_sample_records)": + continue + + task_list.append( + column.to_generation_task( + self.special_system_instructions, client=self._client + ) + ) + + for validator in self._validators.values(): + task_list.append(validator) + + for eval_task in self._evaluators.values(): + task_list.append(eval_task) + + steps = DataDesignerWorkflow.create_steps_from_sequential_tasks( + task_list, verbose_logging=verbose_logging + ) + + return steps, data_seeds + + def _get_data_seeds_task(self, dataset_context=None, **kwargs) -> Task: + """Get the task that extracts data seeds from the sample records.""" + return ExtractDataSeedsFromSampleRecords( + sample_records=self._sample_records, + client=self._client, + dataset_context=dataset_context, + **kwargs, + ) + + def add_categorical_seed_column( + self, + name: str, + *, + description: Optional[str] = None, + values: Optional[list[Union[str, int, float]]] = None, + weights: Optional[list[float]] = None, + num_new_values_to_generate: Optional[int] = None, + subcategories: Optional[Union[list[SeedSubcategory], list[dict]]] = None, + **kwargs, + ) -> None: + """Add a seed category to the data design. + + This method is not supported for DataDesignerFromSampleRecords instances. + If you want to add categorical seed columns, please use the base DataDesigner class. + """ + raise NotImplementedError( + f"Categorical seed columns cannot be added to a {self.__class__.__name__} instance. " + "If you want to add categorical seed columns, please use the base DataDesigner class." 
+ ) + + def get_data_spec( + self, data_seeds: Optional[CategoricalDataSeeds] = None + ) -> DataSpec: + """Return a DataSpec instance that defines the schema and other relevant info.""" + code_lang = None + code_columns = [] + validation_columns = [] + llm_judge_column_name = None + for validator in self._validators.values(): + if validator.name == "validate_code": + column_suffix = ( + VALIDATE_SQL_COLUMN_SUFFIXES + if validator.config.code_lang in SQL_DIALECTS + else VALIDATE_PYTHON_COLUMN_SUFFIXES + ) + code_lang = validator.config.code_lang + code_columns.extend(validator.config.code_columns) + for col in code_columns: + for prefix in column_suffix: + validation_columns.append(f"{col}{prefix}") + break + if self._eval_type in list(LLMJudgePromptTemplateType): + llm_judge_column_name = f"{self._eval_type}_llm_judge_results" + + seed_category_names = ( + self._seed_category_names + if data_seeds is None + else [s.name for s in data_seeds.seed_categories] + ) + + seed_subcategory_names = ( + dict(self._seed_subcategory_names) + if data_seeds is None + else { + s.name: [ss.name for ss in s.subcategories] + for s in data_seeds.seed_categories + } + ) + + # We need this step for sample-to-dataset because we sometimes change + # the column names to follow a standard format. 
+ if data_seeds is not None and data_seeds.dataset_schema_map is not None: + generated_data_column_names = [] + schema_map = data_seeds.dataset_schema_map.get("original_to_new", {}) + for col in self.generated_data_column_names: + if col in schema_map: + generated_data_column_names.append(schema_map[col]) + else: + generated_data_column_names.append(col) + else: + generated_data_column_names = self.generated_data_column_names + + return DataSpec( + seed_category_names=seed_category_names, + seed_subcategory_names=seed_subcategory_names, + data_column_names=generated_data_column_names, + validation_column_names=validation_columns, + code_column_names=code_columns, + code_lang=code_lang, + eval_type=self._eval_type, + llm_judge_column_name=llm_judge_column_name, + ) + + def __repr__(self): + if len(self._seed_categories) > 0: + seed_categories = [ + ( + name + if len(s.subcategories) == 0 + else f"{name}:{','.join([n.name for n in s.subcategories])}" + ) + for name, s in self._seed_categories.items() + ] + categorical_seed_columns = ( + f" categorical_seed_columns: {seed_categories}\n" + if len(seed_categories) > 0 + else "" + ) + else: + categorical_seed_columns = ( + " categorical_seed_columns: (from_sample_records)\n" + ) + + generated_data_columns = [ + ( + f"{c.name} (from_sample_records)" + if c.generation_prompt == "(from_sample_records)" + else c.name + ) + for c in self._generated_data_columns.values() + ] + + generated_data_columns = ( + f" generated_data_columns: {generated_data_columns}\n" + if len(generated_data_columns) > 0 + else "" + ) + + validators = ( + f" validator: {[f'{k}:{v.config.code_lang}' for k,v in self._validators.items()][0]}\n" + if len(self._validators) > 0 + else "" + ) + + evaluator = ( + f" evaluator: {self._eval_type}\n" if self._eval_type is not None else "" + ) + + column_repr = ( + f"{categorical_seed_columns}" + f"{generated_data_columns}" + f"{validators}" + f"{evaluator}" + ) + newline = "\n" if len(column_repr) > 0 else "" + 
+ return f"{self.__class__.__name__}({newline}{column_repr})" diff --git a/src/gretel_client/navigator/data_designer/viz_tools.py b/src/gretel_client/navigator/data_designer/viz_tools.py index 9226d2a1..f4aab8f0 100644 --- a/src/gretel_client/navigator/data_designer/viz_tools.py +++ b/src/gretel_client/navigator/data_designer/viz_tools.py @@ -53,8 +53,8 @@ def create_rich_histogram_table( def display_sample_record( record: Union[dict, pd.Series, pd.DataFrame], - seed_categories: list[str], data_columns: list[str], + seed_categories: Optional[list[str]] = None, seed_subcategories: Optional[dict[str, list[str]]] = None, code_lang: Optional[CodeLang] = None, code_columns: Optional[list[str]] = None, @@ -80,15 +80,17 @@ def display_sample_record( ) code_columns = code_columns or [] + seed_categories = seed_categories or [] seed_subcategories = seed_subcategories or {} validation_columns = validation_columns or [] + code_lang = None if code_lang is None else CodeLang.validate(code_lang) table_kws = dict(show_lines=True, expand=True) render_list = [] if len(seed_categories) > 0: - table = Table(title="Seed Columns", **table_kws) + table = Table(title="Categorical Seed Columns", **table_kws) table.add_column("Name") table.add_column("Value") for col in [c for c in seed_categories if c not in code_columns]: @@ -99,7 +101,7 @@ def display_sample_record( render_list.append(_pad_console_element(table)) if len(data_columns) > 0: - table = Table(title="Data Columns", **table_kws) + table = Table(title="Generated Data Columns", **table_kws) table.add_column("Name") table.add_column("Value") for col in [c for c in data_columns if c not in code_columns]: @@ -160,7 +162,7 @@ def display_sample_record( index_label = Text(f"[index: {record_index}]", justify="center") render_list.append(index_label) - console.print(Group(*render_list)) + console.print(Group(*render_list), markup=False) def display_preview_evaluation_summary( @@ -249,4 +251,4 @@ def display_preview_evaluation_summary( 
render_list.append(dash_sep) - console.print(Group(*render_list)) + console.print(Group(*render_list), markup=False) diff --git a/src/gretel_client/navigator/tasks/constants.py b/src/gretel_client/navigator/tasks/constants.py new file mode 100644 index 00000000..23cadb4a --- /dev/null +++ b/src/gretel_client/navigator/tasks/constants.py @@ -0,0 +1,6 @@ +PREVIEW_NUM_RECORDS = 10 + +# Sample-to-Dataset constants +MAX_NUM_SEEDS = 10 +MAX_SAMPLE_SIZE = 50 +S2D_PREVIEW_NUM_RECORDS = 50 diff --git a/src/gretel_client/navigator/tasks/extract_data_seeds_from_sample_records.py b/src/gretel_client/navigator/tasks/extract_data_seeds_from_sample_records.py new file mode 100644 index 00000000..8556f72e --- /dev/null +++ b/src/gretel_client/navigator/tasks/extract_data_seeds_from_sample_records.py @@ -0,0 +1,66 @@ +from pathlib import Path +from typing import Optional, Union + +import pandas as pd + +from pydantic import BaseModel, Field + +from gretel_client.navigator.client.interface import Client +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.constants import MAX_NUM_SEEDS +from gretel_client.navigator.tasks.types import ( + CategoricalDataSeeds, + DEFAULT_MODEL_SUITE, + ModelSuite, + RecordsT, + SystemPromptType, +) +from gretel_client.navigator.tasks.utils import process_sample_records + +logger = get_logger(__name__, level="INFO") + + +class ExtractDataSeedsFromSampleRecordsConfig(BaseModel): + sample_records: RecordsT + max_num_seeds: int = Field(default=5, ge=1, le=MAX_NUM_SEEDS) + num_assistants: int = Field(default=5, ge=1, le=8) + dataset_context: str = "" + system_prompt_type: SystemPromptType = SystemPromptType.COGNITION + num_samples: int = 25 + + +class ExtractDataSeedsFromSampleRecords(Task): + + def __init__( + self, + sample_records: Union[str, Path, pd.DataFrame, RecordsT], + max_num_seeds: int = 5, + num_assistants: int = 3, + system_prompt_type: SystemPromptType = 
SystemPromptType.COGNITION, + dataset_context: Optional[str] = None, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + ): + sample_records = process_sample_records(sample_records) + super().__init__( + config=ExtractDataSeedsFromSampleRecordsConfig( + sample_records=sample_records, + max_num_seeds=max_num_seeds, + num_assistants=num_assistants, + system_prompt_type=system_prompt_type, + dataset_context=dataset_context or "", + num_samples=len(sample_records), + ), + workflow_label=workflow_label, + client=client, + model_suite=model_suite, + ) + + @property + def name(self) -> str: + return "extract_data_seeds_from_sample_records" + + def run(self) -> CategoricalDataSeeds: + return self._run() diff --git a/src/gretel_client/navigator/tasks/generate_dataset_from_sample_records.py b/src/gretel_client/navigator/tasks/generate_dataset_from_sample_records.py new file mode 100644 index 00000000..2822719a --- /dev/null +++ b/src/gretel_client/navigator/tasks/generate_dataset_from_sample_records.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Optional, Union + +import pandas as pd + +from pydantic import BaseModel, Field + +from gretel_client.navigator.client.interface import Client +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.constants import MAX_SAMPLE_SIZE +from gretel_client.navigator.tasks.types import ( + CategoricalDataSeeds, + DEFAULT_MODEL_SUITE, + ModelSuite, + RecordsT, + SystemPromptType, +) +from gretel_client.navigator.tasks.utils import process_sample_records + +logger = get_logger(__name__, level="INFO") + + +class GenerateDatasetFromSampleRecordsConfig(BaseModel): + sample_records: RecordsT + target_num_records: int = Field(500, ge=MAX_SAMPLE_SIZE, le=10_000) + system_prompt_type: SystemPromptType = SystemPromptType.COGNITION + num_records_per_seed: int = Field(5, ge=1, 
le=10) + num_examples_per_prompt: int = Field(5, ge=1, le=MAX_SAMPLE_SIZE) + dataset_context: str = "" + + +class GenerateDatasetFromSampleRecords(Task): + + def __init__( + self, + sample_records: Union[str, Path, pd.DataFrame, RecordsT], + target_num_records: int = 500, + system_prompt_type: SystemPromptType = SystemPromptType.COGNITION, + num_records_per_seed: int = 5, + num_examples_per_prompt: int = 5, + dataset_context: Optional[str] = None, + workflow_label: Optional[str] = None, + client: Optional[Client] = None, + model_suite: ModelSuite = DEFAULT_MODEL_SUITE, + ): + sample_records = process_sample_records(sample_records) + super().__init__( + config=GenerateDatasetFromSampleRecordsConfig( + sample_records=sample_records, + target_num_records=target_num_records, + system_prompt_type=system_prompt_type, + num_records_per_seed=num_records_per_seed, + num_examples_per_prompt=num_examples_per_prompt, + dataset_context=dataset_context or "", + ), + workflow_label=workflow_label, + client=client, + model_suite=model_suite, + ) + + @property + def name(self) -> str: + return "generate_dataset_from_sample_records" + + def run( + self, categorical_data_seeds: Union[dict, CategoricalDataSeeds] + ) -> CategoricalDataSeeds: + if categorical_data_seeds and isinstance(categorical_data_seeds, dict): + categorical_data_seeds = CategoricalDataSeeds(**categorical_data_seeds) + return self._run( + { + "type": "categorical_data_seeds", + "obj": categorical_data_seeds.model_dump(), + } + ) diff --git a/src/gretel_client/navigator/tasks/types.py b/src/gretel_client/navigator/tasks/types.py index d40a414e..4deb9235 100644 --- a/src/gretel_client/navigator/tasks/types.py +++ b/src/gretel_client/navigator/tasks/types.py @@ -47,6 +47,11 @@ class OutputColumnType(str, Enum): CODE = "code" +class SystemPromptType(str, Enum): + REFLECTION = "reflection" + COGNITION = "cognition" + + class LLMType(str, Enum): NATURAL_LANGUAGE = "natural_language" CODE = "code" @@ -249,6 +254,8 @@ def 
check_dynamic_categories_have_dynamic_subcategories(self) -> Self: class CategoricalDataSeeds(BaseModel): seed_categories: list[SeedCategory] + dataset_schema_map: Optional[dict] = None + @property def needs_generation(self) -> bool: return any(cat.needs_generation for cat in self.seed_categories) or any( @@ -263,8 +270,6 @@ def inspect(self) -> None: "name", "description", "values", - "generated_values", - "num_values_to_generate", "subcategories", ] console = Console() diff --git a/src/gretel_client/navigator/tasks/utils.py b/src/gretel_client/navigator/tasks/utils.py new file mode 100644 index 00000000..2fbc5d0f --- /dev/null +++ b/src/gretel_client/navigator/tasks/utils.py @@ -0,0 +1,90 @@ +import json + +from pathlib import Path +from typing import Optional, Union + +import pandas as pd + +from gretel_client.navigator.log import get_logger +from gretel_client.navigator.tasks.constants import MAX_SAMPLE_SIZE +from gretel_client.navigator.tasks.types import RecordsT + +logger = get_logger(__name__, level="INFO") + + +json_constant_map = { + "-Infinity": float("-Infinity"), + "Infinity": float("Infinity"), + "NaN": None, +} + + +def process_sample_records( + sample_records: Union[str, Path, pd.DataFrame, RecordsT], + subsample_size: Optional[int] = None, +) -> RecordsT: + if isinstance(sample_records, (str, Path)): + sample_records = Path(sample_records) + if sample_records.suffix == ".csv": + sample_records = pd.read_csv(sample_records) + elif sample_records.suffix == ".json": + sample_records = pd.read_json(sample_records) + else: + raise ValueError( + f"Unsupported file format for sample records: {sample_records.suffix}. " + "Supported formats are .csv and .json." + ) + elif isinstance(sample_records, list): + sample_records = pd.DataFrame.from_records(sample_records) + elif not isinstance(sample_records, pd.DataFrame): + raise ValueError( + "Sample records must be a DataFrame, list of records, or a path to a CSV or JSON file." 
+ ) + + sample_size = len(sample_records) + + if sample_size > MAX_SAMPLE_SIZE and ( + subsample_size is None or subsample_size > MAX_SAMPLE_SIZE + ): + raise ValueError( + f"The sample size of {sample_size} records is larger than the " + f"maximum allowed size of {MAX_SAMPLE_SIZE}. Consider setting " + f"subsample_size <= {MAX_SAMPLE_SIZE} to reduce the sample size." + ) + + if subsample_size is not None: + if subsample_size < 1: + raise ValueError("Subsample size must be at least 1.") + + elif subsample_size > MAX_SAMPLE_SIZE: + logger.warning( + f"⚠️ The given subsample size of {subsample_size} is larger than both the input " + f"sample size and the maximum allowed size of {MAX_SAMPLE_SIZE}. We will shuffle " + f"the input data and use the full sample size of {len(sample_records)} records." + ) + + elif subsample_size > sample_size: + logger.warning( + f"⚠️ The given subsample size of {subsample_size} is larger than the number of " + f"records in the sample data. We will shuffle the data and use the " + f"full sample size of {len(sample_records)} records." + ) + + else: + logger.info( + f"🎲 Randomly sampling {subsample_size} records from the input data." + ) + sample_size = subsample_size + + sample_records = ( + sample_records.sample(sample_size, replace=False) + .reset_index(drop=True) + .to_dict(orient="records") + ) + + # Convert NaN and Infinity values to JSON serializable values. 
+ sample_records = [ + json.loads(json.dumps(record), parse_constant=lambda c: json_constant_map[c]) + for record in sample_records + ] + return sample_records diff --git a/src/gretel_client/navigator/workflow.py b/src/gretel_client/navigator/workflow.py index d33cda14..9aa650c6 100644 --- a/src/gretel_client/navigator/workflow.py +++ b/src/gretel_client/navigator/workflow.py @@ -32,9 +32,8 @@ LLMJudgePromptTemplateType, ModelSuite, ) -from gretel_client.projects import Project +from gretel_client.projects.projects import get_project from gretel_client.rest_v1.api.workflows_api import WorkflowsApi -from gretel_client.rest_v1.api_client import ApiClient from gretel_client.workflows.logs import print_logs_for_workflow_run logger = get_logger(__name__, level=logging.INFO) @@ -46,9 +45,10 @@ "evaluate": "🧐", "validate": "πŸ”", "judge": "βš–οΈ", - "sample": "🌱", + "sample": "🎲", "seed": "🌱", "load": "πŸ“₯", + "extract": "πŸ’­", } @@ -60,7 +60,21 @@ def _get_task_log_emoji(task_name: str) -> str: return log_emoji -@dataclass(frozen=True) +def get_task_io_map(client: Client) -> dict: + """Create a mapping of task names to their inputs and output. + + This is helpful for finding the last step to emit a dataset. + """ + task_io = {} + for task in client.registry(): + task_io[task["name"]] = { + "inputs": task["inputs"], + "output": task["output"], + } + return task_io + + +@dataclass class DataSpec: """Specification for dataset created by DataDesigner. 
@@ -95,13 +109,16 @@ def dataset(self) -> Dataset: def display_sample_record( self, index: Optional[int] = None, + *, syntax_highlighting_theme: str = "dracula", background_color: Optional[str] = None, ) -> None: if self.dataset is None: raise ValueError("No dataset found in the preview results.") if self.data_spec is None: - raise ValueError("A data schema is required to display the sample record.") + raise ValueError( + "A data specification is required to display the sample record" + ) i = index or self._display_cycle_index display_sample_record( record=self.dataset.iloc[i], @@ -122,7 +139,7 @@ def display_sample_record( ) -_TERMINAL_STATUSES = [ +TERMINAL_STATUSES = [ "RUN_STATUS_COMPLETED", "RUN_STATUS_ERROR", "RUN_STATUS_CANCELLED", @@ -136,34 +153,34 @@ class Step(BaseModel): inputs: Optional[list[str]] = [] -class BatchWorkflowRun: - workflow_id: str - workflow_run_id: str - _client: Client - _project: Project - _workflow_api: ApiClient - _workflow_step_names: Optional[list[str]] - _last_dataset_step: Optional[Step] - _last_evaluate_step: Optional[Step] +class DataDesignerBatchJob: def __init__( self, - project: Project, - client: Client, - workflow_id: str, workflow_run_id: str, - workflow_step_names: Optional[list[str]] = None, - last_dataset_step: Optional[str] = None, - last_evaluate_step: Optional[str] = None, + client: Client, + *, + last_dataset_step: Optional[Step] = None, + last_evaluate_step: Optional[Step] = None, ): - self.workflow_id = workflow_id + self.workflow_run_id = workflow_run_id + self._client = client - self._project = project - self._workflow_api = project.session.get_v1_api(WorkflowsApi) + self._session = client._adapter._session self._last_dataset_step = last_dataset_step self._last_evaluate_step = last_evaluate_step - self._workflow_step_names = workflow_step_names + self._workflow_api = self._session.get_v1_api(WorkflowsApi) + self._data_spec: Optional[DataSpec] = None + + run = self._get_run() + self.workflow_id = 
run.workflow_id + self.workflow_step_names = [step.name for step in run.actions] + self._project = get_project(name=run.project_id, session=self._session) + self._step_io = { + step.name: get_task_io_map(self._client)[step.action_type] + for step in run.actions + } @property def console_url(self) -> str: @@ -172,172 +189,175 @@ def console_url(self) -> str: f"{self.workflow_id}/runs/{self.workflow_run_id}" ) - def wait_for_completion(self) -> None: - logger.info(f"πŸ‘€ Follow along -> {self.console_url}") - while True: - if self._reached_terminal_status(): - break - time.sleep(10) - - def run_status(self) -> str: - run = self._workflow_api.get_workflow_run(workflow_run_id=self.workflow_run_id) - return run.status - - def _reached_terminal_status(self) -> bool: - status = self.run_status() - return status in _TERMINAL_STATUSES + @property + def workflow_run_status(self) -> str: + return self._get_run().status - def poll_logs(self) -> None: - print_logs_for_workflow_run(self.workflow_run_id, self._project.session) + def _check_if_step_exists(self, step_name: str) -> None: + if step_name not in self.workflow_step_names: + raise ValueError( + f"Step {step_name} not found in workflow." 
+ f"Available steps: {self.workflow_step_names}" + ) - def get_step_output( - self, step_name: str, format: Optional[str] = None + def _get_step_output( + self, step_name: str, output_format: Optional[str] = None ) -> TaskOutput: return self._client.get_step_output( workflow_run_id=self.workflow_run_id, step_name=step_name, - format=format, + format=output_format, ) - def download_step_output( + def _download_step_output( self, step_name: str, - format: Optional[str] = None, + output_format: Optional[str] = None, output_dir: Union[str, Path] = ".", ) -> Path: return self._client.download_step_output( workflow_run_id=self.workflow_run_id, step_name=step_name, output_dir=Path(output_dir), - format=format, + format=output_format, ) - -class NavigatorBatchJob: - - def __init__( - self, - *, - workflow_step_names: list[str], - workflow_run: BatchWorkflowRun, - data_spec: Optional[DataSpec] = None, - ): - self._workflow_run = workflow_run - self.workflow_step_names = workflow_step_names - self.data_spec = data_spec - - @property - def workflow_id(self) -> str: - return self._workflow_run.workflow_id - - @property - def workflow_run_id(self) -> str: - return self._workflow_run.workflow_run_id - - @property - def console_url(self) -> str: - return self._workflow_run.console_url - - @property - def status(self) -> str: - return self._workflow_run.run_status() + def _get_run(self): + return self._workflow_api.get_workflow_run( + workflow_run_id=self.workflow_run_id, expand=["actions"] + ) def _fetch_artifact( self, step_name: str, - artifact_type: str, + output_format: str, wait_for_completion: bool = False, **kwargs, ) -> TaskOutput: - if self.status == "RUN_STATUS_COMPLETED": - if artifact_type == "dataset": - logger.info("πŸ’Ώ Fetching dataset from completed workflow run") - return self._workflow_run.get_step_output(step_name) - elif artifact_type == "report": + status = self.get_step_status(step_name) + if status == "RUN_STATUS_COMPLETED": + if output_format == 
"parquet": + logger.info(f"πŸ’Ώ Fetching dataset from workflow step `{step_name}`") + return self._get_step_output(step_name, output_format=output_format) + elif output_format == "pdf": logger.info( "πŸ“Š Downloading evaluation report from completed workflow run" ) output_dir = kwargs.get("output_dir", ".") - path = self._workflow_run.download_step_output( - step_name, format="pdf", output_dir=output_dir + path = self._download_step_output( + step_name, output_format="pdf", output_dir=output_dir ) logger.info(f"πŸ“„ Evaluation report saved to {path}") return path - elif self.status in {"RUN_STATUS_ERROR", "RUN_STATUS_LOST"}: - logger.error("πŸ›‘ Workflow run failed. Cannot fetch dataset.") - elif self.status in {"RUN_STATUS_CANCELLING", "RUN_STATUS_CANCELLED"}: + elif output_format == "json": + logger.info(f"πŸ“¦ Fetching output from step `{step_name}`") + return self._get_step_output(step_name, output_format=output_format) + else: + raise ValueError(f"Unknown output type: {output_format}") + elif status in {"RUN_STATUS_ERROR", "RUN_STATUS_LOST"}: + logger.error("πŸ›‘ Workflow run failed. Cannot fetch step output.") + elif status in {"RUN_STATUS_CANCELLING", "RUN_STATUS_CANCELLED"}: logger.warning("⚠️ Workflow run was cancelled.") - elif self.status in { + elif status in { "RUN_STATUS_PENDING", "RUN_STATUS_CREATED", "RUN_STATUS_ACTIVE", + "RUN_STATUS_UNKNOWN", }: if wait_for_completion: - logger.info("⏳ Waiting for workflow run to complete...") - self._workflow_run.wait_for_completion() - return self._fetch_artifact(step_name, artifact_type, **kwargs) + logger.info( + f"⏳ Waiting for workflow step `{step_name}` to complete..." + ) + self.wait_for_completion(step_name) + return self._fetch_artifact(step_name, output_format, **kwargs) else: logger.warning( - "πŸ—οΈ We are still building your dataset. " - f"Workflow status: {self.status.split('_')[-1]}. " - "Use the `wait_for_completion` flag to wait for the workflow to complete." 
+ f"πŸ—οΈ We are still building the requested artifact from step '{step_name}'. " + "Set `wait_for_completion=True` to wait for the step to complete. " + f"Workflow status: {self.workflow_run_status}." ) else: - logger.error(f"Unknown workflow status: {self.status}") + logger.error(f"Unknown step status: {status}") - def fetch_step_output(self, step_name: str) -> TaskOutput: - if step_name not in self.workflow_step_names: - raise ValueError( - f"Step {step_name} not found in workflow." - f"Available steps: {self.workflow_step_names}" + def _reached_terminal_status(self, step_name: Optional[str] = None) -> bool: + status = ( + self.workflow_run_status + if step_name is None + else self.get_step_status(step_name) + ) + return status in TERMINAL_STATUSES + + def get_step_status(self, step_name: str) -> str: + self._check_if_step_exists(step_name) + run = self._get_run() + return [a for a in run.actions if a.name == step_name][0].status + + def poll_logs(self) -> None: + print_logs_for_workflow_run(self.workflow_run_id, self._session) + + def wait_for_completion(self, step_name: Optional[str] = None) -> None: + self._check_if_step_exists(step_name) + logger.info(f"πŸ‘€ Follow along -> {self.console_url}") + while True: + if self._reached_terminal_status(step_name): + break + time.sleep(10) + + def fetch_step_output( + self, + step_name: str, + *, + output_format: Optional[str] = None, + wait_for_completion: bool = False, + **kwargs, + ) -> TaskOutput: + self._check_if_step_exists(step_name) + if output_format is None: + output_format = ( + "parquet" if self._step_io[step_name]["output"] == "dataset" else "json" ) - return self._workflow_run.get_step_output(step_name) + return self._fetch_artifact( + step_name, + output_format=output_format, + wait_for_completion=wait_for_completion, + **kwargs, + ) - def fetch_dataset(self, wait_for_completion: bool = False) -> Dataset: - if self._workflow_run._last_dataset_step is None: + def fetch_dataset(self, *, 
wait_for_completion: bool = False) -> Dataset: + if self._last_dataset_step is None: raise ValueError("The Workflow did not contain a dataset.") return self._fetch_artifact( - step_name=self._workflow_run._last_dataset_step.name, - artifact_type="dataset", + step_name=self._last_dataset_step.name, + output_format="parquet", wait_for_completion=wait_for_completion, ) def download_evaluation_report( self, + *, wait_for_completion: bool = False, output_dir: Union[str, Path] = Path("."), ) -> None: - if self._workflow_run._last_evaluate_step is None: + if self._last_evaluate_step is None: raise ValueError("The Workflow did not contain an evaluation step.") return self._fetch_artifact( - step_name=self._workflow_run._last_evaluate_step.name, - artifact_type="report", + step_name=self._last_evaluate_step.name, + output_format="pdf", wait_for_completion=wait_for_completion, output_dir=output_dir, ) - def display_sample_record( - self, - record: Union[dict, pd.Series, pd.DataFrame], - syntax_highlighting_theme: str = "dracula", - background_color: Optional[str] = None, - ) -> None: - if self.data_spec is None: - raise ValueError("A data schema is required to display the sample record.") - display_sample_record( - record=record, - seed_categories=self.data_spec.seed_category_names, - data_columns=self.data_spec.data_column_names, - seed_subcategories=self.data_spec.seed_subcategory_names, - background_color=background_color, - code_columns=self.data_spec.code_column_names, - validation_columns=self.data_spec.validation_column_names, - code_lang=self.data_spec.code_lang, - syntax_highlighting_theme=syntax_highlighting_theme, + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(\n" + f" workflow_run_id: {self.workflow_run_id}\n" + f" workflow_run_status: {self.workflow_run_status}\n" + f" console_url: {self.console_url}\n" + ")" ) -class NavigatorWorkflow: +class DataDesignerWorkflow: def __init__( self, *, @@ -359,23 +379,16 @@ def __init__( "num_records": 
10, "model_suite": self._model_suite, } - self._task_io = {} - # Create a mapping of task names to their inputs and output. - # This is helpful for finding the last step to emit a dataset. - for task in self._client.registry(): - self._task_io[task["name"]] = { - "inputs": task["inputs"], - "output": task["output"], - } + self._task_io = get_task_io_map(self._client) @staticmethod def create_steps_from_sequential_tasks( - task_list: list[Task], verbose_logging: bool = False + task_list: list[Task], *, verbose_logging: bool = False ) -> list[Step]: steps = [] step_names = [] if verbose_logging: - logger.info("βš™οΈ Configuring Navigator Workflow steps:") + logger.info("βš™οΈ Configuring Data Designer Workflow steps:") for i in range(len(task_list)): inputs = [] task = task_list[i] @@ -401,6 +414,7 @@ def create_steps_from_sequential_tasks( def from_sequential_tasks( cls, task_list: list[Task], + *, workflow_name: str = None, session: Optional[ClientConfig] = None, **session_kwargs, @@ -421,6 +435,10 @@ def from_yaml(cls, yaml_str: str) -> Self: def workflow_step_names(self) -> list[str]: return [s.name for s in self._steps] + @property + def step_io_map(self) -> dict[str, dict[str, str]]: + return {s.name: self._task_io[s.task] for s in self._steps} + @property def _last_dataset_step(self) -> Optional[Step]: dataset_steps = [ @@ -534,18 +552,15 @@ def generate_dataset_preview( return preview def submit_batch_job( - self, num_records: int, project_name: Optional[str] = None - ) -> BatchWorkflowRun: + self, num_records: int, *, project_name: Optional[str] = None + ) -> DataDesignerBatchJob: self._globals.update({"num_records": num_records}) response = self._client.submit_batch_workflow( self.to_dict(), num_records, project_name ) - return BatchWorkflowRun( - project=response.project, - client=self._client, - workflow_id=response.workflow_id, + return DataDesignerBatchJob( workflow_run_id=response.workflow_run_id, + client=self._client, 
last_dataset_step=self._last_dataset_step, last_evaluate_step=self._last_evaluation_step, - workflow_step_names=self.workflow_step_names, ) diff --git a/tests/gretel_client/integration/test_navigator_data_designer.py b/tests/gretel_client/integration/test_navigator_data_designer.py index 01073161..a6a2ccf2 100644 --- a/tests/gretel_client/integration/test_navigator_data_designer.py +++ b/tests/gretel_client/integration/test_navigator_data_designer.py @@ -3,7 +3,7 @@ import pandas as pd from gretel_client.config import get_session_config -from gretel_client.navigator import DataDesigner +from gretel_client.navigator import DataDesignerFactory from gretel_client.projects import Project config = """\ @@ -69,16 +69,63 @@ context_column: sql_context """ - -def test_basic_smoke_test(project: Project, tmpdir: Path): - data_designer = DataDesigner.from_config(config, session=get_session_config()) - data_seeds = data_designer.generate_seed_category_values() +sample_records = [ + { + "question": "In January the families visiting a national park see animals 26 times. In February the families that visit the national park see animals three times as many as were seen there in January. Then in March the animals are shyer and the families who visit the national park see animals half as many times as they were seen in February. How many times total did families see an animal in the first three months of the year?", + "answer": "Animals were seen 26 times in January and three times as much in February, 26 x 3 = <<26*3=78>>78 times animals were seen in February.\nIn March the animals were seen 1/2 as many times as were seen in February, 78 / 2 = <<78/2=39>>39 times animals were seen in March.\nIf animals were seen 26 times in January + 78 times in February + 39 times in March = <<26+78+39=143>>143 times animals were seen in the first three months of the year.\n#### 143", + }, + { + "question": "Sarah is checking out books from the library to read on vacation. 
She can read 40 words per minute. The books she is checking out have 100 words per page and are 80 pages long. She will be reading for 20 hours. How many books should she check out?", + "answer": "Each book has 8,000 words because 100 x 80 = <<100*80=8000>>8,000\nShe can finish each book in 200 minutes because 8,000 / 40 = <<8000/40=200>>200\nShe will be reading for 1,200 minutes because 20 x 60 = <<20*60=1200>>1,200\nShe needs to check out 6 books because 1,200 / 200 = <<6=6>>6\n#### 6", + }, + { + "question": "At the beginning of the day there were 74 apples in a basket. If Ricki removes 14 apples and Samson removes twice as many as Ricki. How many apples are left in the basket by the end of the day?", + "answer": "There are 74-14 = <<74-14=60>>60 apples left after Ricki removes some.\nSamson removes 14*2 = <<14*2=28>>28 apples.\nThere are 60-28 = <<60-28=32>>32 apples left after Samson removes some.\n#### 32", + }, + { + "question": "A man drives 60 mph for 3 hours. How fast would he have to drive over the next 2 hours to get an average speed of 70 mph?", + "answer": "To have an average speed of 70 mph over 5 hours he needs to travel 70*5=<<70*5=350>>350 miles.\nHe drove 60*3=<<60*3=180>>180 miles in the first 3 hours\nHe needs to travel another 350-180=<<350-180=170>>170 miles over the next 2 hours.\nHis speed needs to be 170/2=<<170/2=85>>85 mph\n#### 85", + }, + { + "question": "Jaynie wants to make leis for the graduation party. It will take 2 and half dozen plumeria flowers to make 1 lei. If she wants to make 4 leis, how many plumeria flowers must she pick from the trees in her yard?", + "answer": "To make 1 lei, Jaynie will need 2.5 x 12 = <<12*2.5=30>>30 plumeria flowers.\nTo make 4 leis, she will need to pick 30 x 4 = <<30*4=120>>120 plumeria flowers from the trees.\n#### 120", + }, + { + "question": "A school is buying virus protection software to cover 50 devices. One software package costs $40 and covers up to 5 devices. 
The other software package costs $60 and covers up to 10 devices. How much money, in dollars, can the school save by buying the $60 software package instead of the $40 software package?", + "answer": "There are 50/5 = <<50/5=10>>10 sets of 5 devices in the school.\nSo the school will pay a total of $40 x 10 = $<<40*10=400>>400 for the $40 software package.\nThere are 50/10 = <<50/10=5>>5 sets of 10 devices in the school.\nSo the school will pay a total of $60 x 5 = $<<60*5=300>>300 for the $60 software package.\nThus, the school can save $400 - $100 = $<<400-100=300>>300 from buying the $60 software instead of the $40 software package.\n#### 100", + }, + { + "question": "Quinten sees three buildings downtown and decides to estimate their heights. He knows from a book on local buildings that the one in the middle is 100 feet tall. The one on the left looks like it is 80% of the height of the middle one. The one on the right looks 20 feet shorter than if the building on the left and middle were stacked on top of each other. How tall does Quinten estimate their total height to be?", + "answer": "He estimates the building on the left is 80 feet tall because 100 x .8 = <<100*.8=80>>80\nThe combined height of the left and middle is 180 because 100 + 80 = <<100+80=180>>180\nThe building on the right he estimates as 160 feet because 180 - 20 = <<180-20=160>>160\nHe estimates the combined height as 340 feet because 80 + 100 + 160 = <<80+100+160=340>>340\n#### 340", + }, + { + "question": "At a pool party, there are 4 pizzas cut into 12 slices each. If the guests eat 39 slices, how many slices are left?", + "answer": "There’s a total of 4 x 12 = <<4*12=48>>48 slices.\nAfter the guests eat, there are 48 - 39 = <<48-39=9>>9 pieces.\n#### 9", + }, + { + "question": "A farmer gets 20 pounds of bacon on average from a pig. He sells each pound for $6 at the monthly farmer’s market. This month’s pig is a runt that grew to only half the size of the average pig. 
How many dollars will the farmer make from the pig’s bacon?", + "answer": "The pig grew to half the size of the average pig, so it will produce 20 / 2 = <<20/2=10>>10 pounds of bacon.\nThe rancher will make 10 * 6 = $<<10*6=60>>60 from the pig’s bacon.\n#### 60", + }, + { + "question": "Legacy has 5 bars of gold she received from her father. Her friend Aleena has 2 bars fewer than she has. If a bar of gold is worth $2200, calculate the total value of gold the three have together.", + "answer": "If Legacy has 5 bars, Aleena has 5 bars - 2 bars = <<5-2=3>>3 bars.\nIn total, they have 5 bars + 3 bars = <<5+3=8>>8 bars,\nSince one bar of gold is worth $2200, the 8 bars they have together are worth 8 bars * $2200/bar = $<<8*2200=17600>>17600\n#### 17600", + }, +] + + +def test_from_config_smoke_test(project: Project, tmpdir: Path): + data_designer = DataDesignerFactory.from_config( + config, session=get_session_config() + ) + data_seeds = data_designer.run_data_seeds_step() preview = data_designer.generate_dataset_preview(data_seeds=data_seeds) preview_df: pd.DataFrame = preview.output assert len(preview_df) == 10 + preview.display_sample_record() + batch_job = data_designer.submit_batch_workflow( num_records=10, project_name=project.name, data_seeds=data_seeds ) @@ -86,6 +133,41 @@ def test_basic_smoke_test(project: Project, tmpdir: Path): assert len(df) == 10 - path: Path = batch_job.download_evaluation_report(output_dir=tmpdir) + path: Path = batch_job.download_evaluation_report( + output_dir=tmpdir, wait_for_completion=True + ) assert len(path.read_bytes()) > 0 + + +def test_from_sample_records_smoke_test(project: Project, tmpdir: Path): + data_designer = DataDesignerFactory.from_sample_records( + sample_records, session=get_session_config() + ) + + generation_prompt = """\ + Provide a thoughtful analysis of the quality of the answer to the provided question. + + In your analysis, consider whether the answer is accurate, relevant, and complete. 
+ + *** Question *** + {question} + + *** Answer *** + {answer} + """ + + data_designer.add_generated_data_column( + name="analysis", generation_prompt=generation_prompt + ) + + data_designer.add_evaluator("general") + + data_seeds = data_designer.run_data_seeds_step() + preview = data_designer.generate_dataset_preview(data_seeds=data_seeds) + + preview_df: pd.DataFrame = preview.output + + assert len(preview_df) == 50 + + preview.display_sample_record()