-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Project import generated by Copybara.
GitOrigin-RevId: c89859d506acff032a4d0bf4ee2f83491cbcc959
- Loading branch information
1 parent
3663930
commit 9440870
Showing
26 changed files
with
1,179 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from gretel_client.navigator.workflow import NavigatorWorkflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from gretel_client.navigator.blueprints.text_to_code.blueprint import ( | ||
TextToCodeBlueprint, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from abc import ABC | ||
|
||
|
||
class NavigatorBlueprint(ABC):
    """Common base class for Navigator blueprints.

    Exposes a ``name`` derived from the concrete class and uses it for both
    string representations.
    """

    @property
    def name(self) -> str:
        """Return the blueprint name, i.e. the name of the concrete class."""
        return type(self).__name__

    def __repr__(self) -> str:
        # Angle-bracketed form, e.g. ``<TextToCodeBlueprint>``.
        return "<{}>".format(self.name)

    def __str__(self) -> str:
        return self.name
Empty file.
Empty file.
162 changes: 162 additions & 0 deletions
162
src/gretel_client/navigator/blueprints/text_to_code/blueprint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
from gretel_client.gretel.config_setup import smart_load_yaml | ||
from gretel_client.navigator.blueprints.base import NavigatorBlueprint | ||
from gretel_client.navigator.blueprints.text_to_code.prompt_templates import ( | ||
CODE_PROMPT, | ||
FIELD_GENERATION_PROMPT, | ||
TEXT_PROMPT, | ||
) | ||
from gretel_client.navigator.blueprints.text_to_code.utils import display_nl2code_sample | ||
from gretel_client.navigator.tasks import ( | ||
GenerateColumnFromTemplate, | ||
GenerateSeedValues, | ||
SampleDataSeeds, | ||
ValidateCode, | ||
) | ||
from gretel_client.navigator.tasks.base import Task | ||
from gretel_client.navigator.tasks.io import Dataset | ||
from gretel_client.navigator.workflow import NavigatorWorkflow | ||
|
||
# Instruction bullet appended to a field-generation prompt, keyed by output
# parser name. Every parser that ``output_parser_type_map`` can produce needs
# an entry here; a missing entry makes the prompt construction fail with
# ``KeyError`` (this was the case for "json" and "extract_code").
output_parser_instructions = {
    "pass_through": "* Return only the requested text, without any additional comments or instructions.",
    "json_array": "* Respond only with the list as a valid JSON array.",
    "json": "* Respond only with a valid JSON object.",
    "extract_code": "* Include ONLY a SINGLE block of code WITHOUT ANY additional text.",
}

# Maps the ``column_type`` declared for a contextual column in the blueprint
# config to the name of the output parser used for that column.
output_parser_type_map = {
    "str": "pass_through",
    "string": "pass_through",
    "text": "pass_through",
    "json": "json",
    "dict": "json",
    "list": "json_array",
    "json_array": "json_array",
    "code": "extract_code",
}
|
||
|
||
@dataclass
class DataPreview:
    """Result of a blueprint preview run plus a helper to display one record."""

    # Preview records returned by the workflow.
    # NOTE(review): used via ``.sample``/``.iloc``/``.loc``, so this is
    # presumably a pandas DataFrame -- confirm against ``tasks.io.Dataset``.
    dataset: Dataset
    # Names of the columns shown in the "Contextual Columns" table. These are
    # column names (strings), not dicts: they are built from the ``name`` of
    # each seed / additional contextual column.
    contextual_columns: list[str]
    # Blueprint config; ``programming_language`` is read from it for display.
    blueprint_config: dict
    # Seed values keyed by seed-column name.
    data_seeds: dict

    def display_sample(self, index: Optional[int] = None, **kwargs) -> None:
        """Print one record of the preview dataset.

        Args:
            index: Label of the record to show (looked up with ``.loc``). When
                ``None``, a single record is sampled at random.
            **kwargs: Forwarded to ``display_nl2code_sample``.
        """
        if index is None:
            record = self.dataset.sample(1).iloc[0]
        else:
            record = self.dataset.loc[index]
        display_nl2code_sample(
            lang=self.blueprint_config["programming_language"],
            record=record,
            contextual_columns=self.contextual_columns,
            **kwargs,
        )
|
||
|
||
class TextToCodeBlueprint(NavigatorBlueprint):
    """Blueprint that assembles a Navigator workflow for text-to-code data.

    The config must provide ``programming_language``, ``seed_generation``,
    ``text_relevant_columns``, ``code_relevant_columns``,
    ``text_generation_instructions`` and ``code_generation_instructions``.
    ``additional_contextual_columns`` is optional.
    """

    def __init__(self, config: Union[str, dict, Path], **session_kwargs):
        """Load the config and register the blueprint's tasks on a workflow.

        Args:
            config: Blueprint config, in any form ``smart_load_yaml`` accepts.
            **session_kwargs: Forwarded to ``NavigatorWorkflow``.
        """
        self.config = smart_load_yaml(config)
        self.lang = self.config["programming_language"]
        self.workflow = NavigatorWorkflow(**session_kwargs)
        self.task_list = self._build_sequential_task_list()
        self.workflow.add_steps(
            self.workflow.create_steps_from_sequential_tasks(self.task_list)
        )

    def _create_context_template(self, columns: list) -> str:
        """Return one bullet line per column, each ending in a ``{column}`` placeholder."""
        return "\n".join(
            f" * {c.replace('_', ' ').capitalize()}: {{{c}}}" for c in columns
        )

    def _create_contextual_column_task(self, field) -> Task:
        """Build the generation task for one additional contextual column.

        Args:
            field: Column spec with the keys ``name``, ``description``,
                ``column_type``, ``llm_type`` and ``relevant_columns``.

        Raises:
            KeyError: If ``column_type`` is not in ``output_parser_type_map``.
        """
        output_parser = output_parser_type_map[field["column_type"]]
        generation_type = "text" if field["llm_type"] == "nl" else "code"
        system_prompt = self.config[f"{generation_type}_generation_instructions"]
        # Not every parser in the type map is guaranteed to have dedicated
        # instructions; fall back to the generic bullet instead of failing
        # with KeyError for e.g. "json" or "code" columns.
        parser_instructions = output_parser_instructions.get(
            output_parser, output_parser_instructions["pass_through"]
        )
        return GenerateColumnFromTemplate(
            prompt_template=FIELD_GENERATION_PROMPT.format(
                name=field["name"],
                description=field["description"],
                context=self._create_context_template(field["relevant_columns"]),
                generation_type=generation_type.capitalize(),
                parser_instructions=parser_instructions,
            ),
            response_column_name=field["name"],
            system_prompt=system_prompt,
            workflow_label=field["name"].replace("_", " "),
            llm_type=field["llm_type"],
            output_parser=output_parser,
        )

    def _build_sequential_task_list(self) -> list[Task]:
        """Return the workflow tasks in execution order.

        Order: seed generation, seed sampling, optional contextual columns,
        the text column, the code column, then code validation.
        """
        additional_context_columns = [
            self._create_contextual_column_task(field)
            for field in self.config.get("additional_contextual_columns", [])
        ]

        generate_text_column = GenerateColumnFromTemplate(
            prompt_template=TEXT_PROMPT.format(
                lang=self.lang,
                context=self._create_context_template(
                    self.config["text_relevant_columns"]
                ),
            ),
            llm_type="nl",
            response_column_name="text",
            system_prompt=self.config["text_generation_instructions"],
            workflow_label="text prompt",
        )

        # NOTE(review): the code column is generated with llm_type="nl" even
        # though contextual columns distinguish "nl" from code generation --
        # confirm whether "code" was intended here.
        generate_code_column = GenerateColumnFromTemplate(
            prompt_template=CODE_PROMPT.format(
                lang=self.lang,
                context=self._create_context_template(
                    self.config["code_relevant_columns"]
                ),
            ),
            llm_type="nl",
            response_column_name="code",
            system_prompt=self.config["code_generation_instructions"],
            workflow_label="code prompt",
            output_parser="extract_code",
        )

        return [
            GenerateSeedValues(**self.config["seed_generation"]),
            SampleDataSeeds(),
            *additional_context_columns,
            generate_text_column,
            generate_code_column,
            # Validate in the configured language rather than a hard-coded
            # "python". Lower-cased, as is done for the syntax lexer.
            # NOTE(review): assumes ValidateCode accepts the lower-cased
            # language name -- confirm the supported values.
            ValidateCode(self.lang.lower()),
        ]

    def generate_dataset_preview(self) -> DataPreview:
        """Run a workflow preview and wrap the result in a ``DataPreview``."""
        results = self.workflow.generate_dataset_preview()

        # Seed values come from the first auxiliary output of the run.
        seeds = {
            s["name"]: s["starting_values"] + s["generated_values"]
            for s in results.auxiliary_outputs[0]["seed_columns"]
        }

        additional_context = self.config.get("additional_contextual_columns", [])
        context_cols = [
            s["name"] for s in self.config["seed_generation"]["seed_columns"]
        ]
        return DataPreview(
            dataset=results.dataset,
            contextual_columns=context_cols
            + [field["name"] for field in additional_context],
            blueprint_config=self.config,
            data_seeds=seeds,
        )

    def submit_batch_job(
        self, num_records: int, project_name: Optional[str] = None
    ) -> None:
        """Submit the workflow as a batch job.

        Args:
            num_records: Number of records to generate.
            project_name: Optional project name passed to the workflow.
        """
        self.workflow.submit_batch_job(
            num_records=num_records, project_name=project_name
        )
49 changes: 49 additions & 0 deletions
49
src/gretel_client/navigator/blueprints/text_to_code/prompt_templates.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Prompt template for the natural-language ("text") column.
# `{lang}` and `{context}` are str.format placeholders, filled by
# TextToCodeBlueprint when it builds the text-generation task.
TEXT_PROMPT = """\ | ||
Your task is to generate the natural language component of a text-to-{lang} dataset, \ | ||
carefully following the given context and instructions. | ||
### Context: | ||
{context} | ||
### Instructions: | ||
* Generate text related to {lang} code based on the given context. | ||
* Do NOT return any code in the response. | ||
* Return only the requested text, without any additional comments or instructions. | ||
### Text: | ||
""" |
|
||
|
||
# Prompt template for the "code" column.
# `{lang}` and `{context}` are str.format placeholders. `{{text}}` is escaped,
# so a literal `{text}` placeholder remains after formatting -- presumably
# resolved later against the generated "text" column; TODO confirm.
CODE_PROMPT = """\ | ||
Your task is to generate {lang} code that corresponds to the text and context given below. | ||
### Text: | ||
{{text}} | ||
### Context: | ||
{context} | ||
### Instructions: | ||
* Remember to base your response on the given context. | ||
* Include ONLY a SINGLE block of code WITHOUT ANY additional text. | ||
### Code: | ||
""" |
|
||
|
||
# Prompt template for an additional contextual column.
# str.format placeholders: `{name}`, `{description}`, `{context}` and
# `{parser_instructions}` (an instruction bullet chosen per output parser).
FIELD_GENERATION_PROMPT = """\ | ||
Your task is to generate a `{name}` field in a dataset based on the given description and context. | ||
### Description: | ||
{description} | ||
### Context: | ||
{context} | ||
### Instructions: | ||
* Generate `{name}` as described above. | ||
* Remember to base your response on the given context. | ||
{parser_instructions} | ||
### Response: | ||
""" |
50 changes: 50 additions & 0 deletions
50
src/gretel_client/navigator/blueprints/text_to_code/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Optional, Union | ||
|
||
import pandas as pd | ||
|
||
from rich.console import Console | ||
from rich.panel import Panel | ||
from rich.syntax import Syntax | ||
from rich.table import Table | ||
from rich.text import Text | ||
|
||
# Module-level rich console used by display_nl2code_sample for all output.
console = Console() | ||
|
||
|
||
def display_nl2code_sample(
    lang: str,
    record: Union[dict, pd.Series],
    contextual_columns: list[str],
    theme: str = "dracula",
    background_color: Optional[str] = None,
):
    """Print one text-to-code record to the console using rich.

    Output is a "Contextual Columns" table, followed by a "Text" panel and a
    syntax-highlighted "Code" panel.

    Args:
        lang: Programming language; lower-cased and used as the lexer name.
        record: Record holding the contextual columns plus ``text`` and ``code``.
        contextual_columns: Names of the columns to show in the table.
        theme: Syntax-highlighting theme passed to ``Syntax``.
        background_color: Optional background color passed to ``Syntax``.

    Raises:
        ValueError: If ``record`` is neither a dict nor a pandas Series.
    """
    if not isinstance(record, (dict, pd.Series)):
        raise ValueError("record must be a dictionary or pandas Series")

    # Round-trip through a one-row DataFrame so that dicts and Series are
    # both handled as a Series from here on.
    row = pd.DataFrame([record]).iloc[0]

    context_table = Table(title="Contextual Columns")
    for column in contextual_columns:
        context_table.add_column(column.replace("_", " ").capitalize())
    cells = [str(row[column]) for column in contextual_columns]
    context_table.add_row(*cells)
    console.print(context_table)

    text_body = Text(row.text, justify="left", overflow="fold")
    console.print(Panel(text_body, title="Text"))

    code_body = Syntax(
        row.code,
        lexer=lang.lower(),
        theme=theme,
        word_wrap=True,
        background_color=background_color,
    )
    console.print(Panel(code_body, title="Code"))
Empty file.
Oops, something went wrong.