Skip to content

Commit

Permalink
CSSE-33/johnny/move gptx hyperparams to params section
Browse files Browse the repository at this point in the history
* added params to gptx config

also updated interface for dgan and gptx

* fixed validation methods

* update gpt-x config unit tests

* fix extra kwargs option

* fixed refresh bug

* update config tests to use new format

* update config parsing logic

* fixed extra_kwargs bug; updated tests

* added tests for timeseries train + generate

* fix gptx config quite poll; ref_data support

* add status code api reporting error message

* add edge case gpt_x backwards compat check

* added future warning for the old config format

* address john's review comments

* a couple more report_kind -> report_type

* support optional data source models

* implement piotr's comments

* use conftest env

* hardcode endpoint in test; tabular in file name

* missed a Gretel instantiation

* Apply suggestions from malte's code review

Co-authored-by: Malte Isberner <2822367+misberner@users.noreply.github.com>

* implement malte's suggestions

* missed usage of extract model config function

* found older comment from malte

* ahhh... found one more

* print project url when creating a new one

* move project url logging to set_project

* update None type hint

---------

Co-authored-by: Malte Isberner <2822367+misberner@users.noreply.github.com>
GitOrigin-RevId: 9acd6b1845c7037a052b554e9233f59f0a1f4593
  • Loading branch information
johnnygreco and misberner committed Oct 17, 2023
1 parent 50fd89c commit f95e00c
Show file tree
Hide file tree
Showing 11 changed files with 2,540 additions and 106 deletions.
26 changes: 23 additions & 3 deletions src/gretel_client/gretel/artifact_fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import webbrowser

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Union

Expand All @@ -28,6 +29,20 @@
logger.setLevel(logging.INFO)


class ReportType(str, Enum):
    """The kind of report to fetch.

    Members are string-valued, so a member can be built from its value,
    e.g. ``ReportType("sqs")``.
    """

    SQS = "sqs"
    TEXT = "text"

    @property
    def artifact_name(self) -> str:
        """Return the name of the model artifact tied to this report type.

        Raises:
            NotImplementedError: If this member has no artifact name defined.
        """
        if self is ReportType.SQS:
            return "report"
        if self is ReportType.TEXT:
            return "text_metrics_report"
        # Fail loudly instead of implicitly returning None if a new member
        # is added without a matching artifact name.
        raise NotImplementedError(f"No artifact name is defined for {self!r}.")


@dataclass
class GretelReport:
"""Dataclass for a Gretel synthetic data quality report."""
Expand Down Expand Up @@ -79,20 +94,25 @@ def fetch_model_logs(model: Model) -> List[dict]:
return model_logs


def fetch_model_report(
    model: Model, report_type: Union[ReportType, str] = ReportType.SQS
) -> GretelReport:
    """Fetch the quality report from a model training job.

    Args:
        model: The Gretel model object.
        report_type: The type of report to fetch. One of "sqs" or "text".

    Returns:
        The Gretel report object.

    Raises:
        ValueError: If `report_type` is not a valid `ReportType` value.
    """
    # Normalize so callers may pass either a ReportType member or its string value.
    report_type = ReportType(report_type)
    artifact_name = report_type.artifact_name

    # The JSON form of the report is stored under "<artifact_name>_json".
    with model.get_artifact_handle(f"{artifact_name}_json") as file:
        report_dict = json.load(file)

    # The handle yields bytes; decode them into the HTML string.
    with model.get_artifact_handle(artifact_name) as file:
        report_html = str(file.read(), encoding="utf-8")

    return GretelReport(as_dict=report_dict, as_html=report_html)
Expand Down
144 changes: 122 additions & 22 deletions src/gretel_client/gretel/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,68 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from gretel_client.gretel.artifact_fetching import ReportType
from gretel_client.gretel.exceptions import BaseConfigError, ConfigSettingError
from gretel_client.projects.exceptions import ModelConfigError
from gretel_client.projects.models import read_model_config


class ModelName(str, Enum):
    """Name of the model parameter dict in the config.

    Note: The values are the names used in the model configs.
    """

    # tabular
    ACTGAN = "actgan"
    AMPLIFY = "amplify"
    LSTM = "synthetics"
    TABULAR_DP = "tabular_dp"

    # text
    GPT_X = "gpt_x"

    # time series
    DGAN = "timeseries_dgan"


@dataclass(frozen=True)
class ModelConfigSections:
    """Config sections for each model type.

    Args:
        model_name: Model name used in the url of the model docs.
        config_sections: List of nested config sections (e.g., `params`).
        data_source_optional: If True, the `data_source` config parameter is optional.
        report_type: The type of quality report generated by the model.
        extra_kwargs: List of non-nested config sections.
    """

    model_name: str
    config_sections: List[str]
    data_source_optional: bool
    # None means no report type is associated with the model.
    report_type: Optional[ReportType] = None
    # None means the model accepts no non-nested keyword arguments.
    extra_kwargs: Optional[List[str]] = None


CONFIG_SETUP_DICT = {
ConfigDictName.ACTGAN: ModelConfigSections(
ModelName.ACTGAN: ModelConfigSections(
model_name="actgan",
config_sections=["params", "generate", "privacy_filters", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ConfigDictName.AMPLIFY: ModelConfigSections(
ModelName.AMPLIFY: ModelConfigSections(
model_name="amplify",
config_sections=["params", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ConfigDictName.LSTM: ModelConfigSections(
ModelName.LSTM: ModelConfigSections(
model_name="lstm",
config_sections=[
"params",
Expand All @@ -47,14 +75,60 @@ class ModelConfigSections:
"privacy_filters",
"evaluate",
],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ConfigDictName.TABULAR_DP: ModelConfigSections(
ModelName.TABULAR_DP: ModelConfigSections(
model_name="tabular_dp",
config_sections=["params", "generate", "evaluate"],
data_source_optional=False,
report_type=ReportType.SQS,
extra_kwargs=["ref_data"],
),
ModelName.GPT_X: ModelConfigSections(
model_name="gpt",
config_sections=["params", "generate"],
data_source_optional=True,
report_type=ReportType.TEXT,
extra_kwargs=["pretrained_model", "column_name", "validation", "ref_data"],
),
ModelName.DGAN: ModelConfigSections(
model_name="dgan",
config_sections=["params", "generate"],
data_source_optional=False,
report_type=None,
extra_kwargs=[
"attribute_columns",
"df_style",
"discrete_columns",
"example_id_column",
"feature_columns",
"ref_data",
"time_column",
],
),
}


def _backwards_compat_transform_config(
    config: dict, non_default_settings: dict
) -> dict:
    """Fold user-supplied `params` into an old-format GPT-X base config.

    If the base config is in old format *and* the user passes in a params dict,
    move the non-default params to the base (non-nested) config level to be
    consistent with old format.

    Args:
        config: The full model config. Its model section is updated in place.
        non_default_settings: The user's overrides. The `params` entry is
            removed from this dict when it is folded into the config.

    Returns:
        The same `config` object that was passed in.
    """
    model_type, section = extract_model_config_section(config)

    # Nothing to do unless this is a GPT-X config without a nested `params`
    # section (the old format) and the caller actually supplied `params`.
    if (
        model_type != ModelName.GPT_X.value
        or "params" not in non_default_settings
        or "params" in section
    ):
        return config

    user_params = non_default_settings.pop("params")
    section.update(user_params)
    return config


def create_model_config_from_base(
base_config: Union[str, Path],
job_label: Optional[str] = None,
Expand All @@ -64,12 +138,14 @@ def create_model_config_from_base(
To update the base config, pass in keyword arguments, where the keys
are config section names and the values are dicts of non-default settings.
If the parameter is not nested within a section, pass it directly as
a keyword argument.
The base config can be given as a yaml file path or the name of one of
the Gretel template files (without the extension) listed here:
https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics
Example::
Examples::
# Create an ACTGAN config with 10 epochs.
from gretel_client.gretel.config_setup import create_model_config_from_base
Expand All @@ -78,6 +154,15 @@ def create_model_config_from_base(
params={"epochs": 10},
)
# Create a GPT config with a custom column name and 100 epochs.
from gretel_client.gretel.config_setup import create_model_config_from_base
config = create_model_config_from_base(
base_config="natural-language",
column_name="custom_name", # not nested in a config section
params={"epochs": 100}, # nested in the `params` section
)
The model configs are documented at
https://docs.gretel.ai/reference/synthetics/models. For ACTGAN, the
available config sections are `params`, `generate`, `privacy_filters`,
and `evaluate`.
Expand Down Expand Up @@ -110,26 +195,41 @@ def create_model_config_from_base(
"gretel-blueprints/tree/main/config_templates/gretel/synthetics"
) from e

dict_name = list(config["models"][0].keys())[0]
setup = CONFIG_SETUP_DICT[ConfigDictName(dict_name)]
model_type, model_config_section = extract_model_config_section(config)
setup = CONFIG_SETUP_DICT[ModelName(model_type)]

config = _backwards_compat_transform_config(config, non_default_settings)

if job_label is not None:
config["name"] = f"{config['name']}-{job_label}"

for section, settings in non_default_settings.items():
if section not in setup.config_sections:
if not isinstance(settings, dict):
extra_kwargs = setup.extra_kwargs or []
if section in extra_kwargs:
model_config_section[section] = settings
else:
raise ConfigSettingError(
f"`{section}` is an invalid keyword argument. Valid options "
f"include {setup.config_sections + extra_kwargs}."
)
elif section not in setup.config_sections:
raise ConfigSettingError(
f"`{section}` is not a valid `{setup.model_name}` config section. "
f"Must be one of [{setup.config_sections}]."
)
if not isinstance(settings, dict):
raise ConfigSettingError(
f"Invalid value for the `{section}` keyword argument. "
f"Must be a dict, but you gave `{type(settings)}`."
)
for k, v in settings.items():
if section not in config["models"][0][dict_name]:
config["models"][0][dict_name][section] = {}
config["models"][0][dict_name][section][k] = v

else:
model_config_section.setdefault(section, {}).update(settings)
return config


def extract_model_config_section(config: dict) -> Tuple[str, dict]:
    """Extract the model type and config dict from a Gretel model config.

    Only the first entry of `config["models"]` is inspected, and only the
    first key/value pair of that entry is returned.

    Args:
        config: The full Gretel config.

    Returns:
        A tuple of the model type and the model section from the config.
    """
    first_model = config["models"][0]
    model_type, model_section = next(iter(first_model.items()))
    return model_type, model_section
Loading

0 comments on commit f95e00c

Please sign in to comment.