Commit
add generate method to base metric
* add generate method

* pandas is a dependency of tuner

* test seed data

* address review comments

GitOrigin-RevId: 3d6beb471a757e3f82d2003b632d94cd051cf986
johnnygreco committed Dec 5, 2023
1 parent 5816ea8 commit 72a8ead
Showing 2 changed files with 98 additions and 2 deletions.
63 changes: 62 additions & 1 deletion src/gretel_client/tuner/metrics.py
@@ -2,13 +2,19 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import List
from pathlib import Path
from typing import List, Optional, Union

import pandas as pd

from gretel_client.gretel.artifact_fetching import fetch_synthetic_data
from gretel_client.gretel.config_setup import (
    CONFIG_SETUP_DICT,
    extract_model_config_section,
    ModelName,
)
from gretel_client.gretel.exceptions import GretelJobSubmissionError
from gretel_client.helpers import poll
from gretel_client.projects.models import Model


@@ -66,6 +72,7 @@ class BaseTunerMetric(ABC):
    direction: MetricDirection = MetricDirection.MAXIMIZE

    def get_gretel_report(self, model: Model) -> dict:
        """Get the Gretel synthetic data quality report."""
        model_type, _ = extract_model_config_section(model.model_config)
        report_type = CONFIG_SETUP_DICT[model_type].report_type
        if report_type is None:
@@ -77,6 +84,60 @@ def get_gretel_report(self, model: Model) -> dict:
            report = json.load(file)
        return report

    def submit_generate_for_trial(
        self,
        model: Model,
        num_records: Optional[int] = None,
        seed_data: Optional[Union[str, Path, pd.DataFrame]] = None,
        **generate_kwargs,
    ) -> pd.DataFrame:
        """Submit generate job for hyperparameter tuning trial.

        Only one of `num_records` or `seed_data` can be provided. The former
        will generate a complete synthetic dataset, while the latter will
        conditionally generate synthetic data based on the seed data.

        Args:
            model: Gretel `Model` instance.
            num_records: Number of records to generate.
            seed_data: Seed data source as a file path or pandas DataFrame.

        Raises:
            TypeError: If `model` is not a Gretel `Model` instance.
            GretelJobSubmissionError: If the combination of arguments is invalid.

        Returns:
            Pandas DataFrame containing the synthetic data.
        """
        if not isinstance(model, Model):
            raise TypeError(f"Expected a Gretel Model object, got {type(model)}.")

        if num_records is not None and seed_data is not None:
            raise GretelJobSubmissionError(
                "Only one of `num_records` or `seed_data` can be provided."
            )

        if num_records is None and seed_data is None:
            raise GretelJobSubmissionError(
                "Either `num_records` or `seed_data` must be provided."
            )

        if num_records is not None:
            generate_kwargs.update({"num_records": num_records})

        data_source = str(seed_data) if isinstance(seed_data, Path) else seed_data

        record_handler = model.create_record_handler_obj(
            data_source=data_source,
            params=generate_kwargs,
        )

        record_handler.submit()

        poll(record_handler, verbose=False)

        return fetch_synthetic_data(record_handler)

    @abstractmethod
    def __call__(self, model: Model) -> float:
        """Calculate the optimization metric and return the score as a float."""
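For context, a minimal sketch (not part of this diff) of how a custom tuner metric could use the new `submit_generate_for_trial` helper in both modes described by its docstring. The subclass name and the `account_type`/`balance` columns are illustrative assumptions, not part of the library.

import pandas as pd

from gretel_client.projects.models import Model
from gretel_client.tuner.metrics import BaseTunerMetric, MetricDirection


class MeanBalanceMetric(BaseTunerMetric):
    """Illustrative metric: mean of a hypothetical `balance` column."""

    direction: MetricDirection = MetricDirection.MAXIMIZE

    def __call__(self, model: Model) -> float:
        # Unconditional generation: request a fixed number of records.
        df_full = self.submit_generate_for_trial(model, num_records=100)

        # Conditional generation: seed values as a DataFrame (a CSV path also works).
        seeds = pd.DataFrame({"account_type": ["checking"] * 10})
        df_seeded = self.submit_generate_for_trial(model, seed_data=seeds)

        # Passing both num_records and seed_data (or neither) raises
        # GretelJobSubmissionError, so each call uses exactly one mode.
        return float(pd.concat([df_full, df_seeded])["balance"].mean())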
37 changes: 36 additions & 1 deletion tests/gretel_client/integration/test_tuner.py
@@ -4,10 +4,16 @@
from pathlib import Path
from typing import Callable

import pandas as pd
import pytest

from gretel_client.gretel.config_setup import extract_model_config_section
from gretel_client.gretel.interface import Gretel
from gretel_client.tuner.metrics import (
    BaseTunerMetric,
    GretelQualityScore,
    MetricDirection,
)


@pytest.fixture
@@ -31,7 +37,34 @@ def gretel() -> Gretel:
    gretel.get_project().delete()


def test_tuner_tabular(gretel: Gretel, tabular_data_source: Path, tuner_config: Path):
class CustomMetric(BaseTunerMetric):
    direction: MetricDirection = MetricDirection.MINIMIZE

    def __call__(self, model):
        seed_data = pd.DataFrame({"yea-mon": ["01/01/1998"] * 10})
        df = pd.concat(
            [
                self.submit_generate_for_trial(model, seed_data=seed_data),
                self.submit_generate_for_trial(model, num_records=100),
            ]
        )
        return df["balance"].sum()


@pytest.mark.parametrize(
    "metric,direction",
    [
        (GretelQualityScore(), MetricDirection.MAXIMIZE),
        (CustomMetric(), MetricDirection.MINIMIZE),
    ],
)
def test_tuner_tabular(
    gretel: Gretel,
    tabular_data_source: Path,
    tuner_config: Path,
    metric: BaseTunerMetric,
    direction: str,
):
    def callback(c):
        c["params"]["discriminator_dim"] = c["params"]["generator_dim"]
        return c
@@ -42,10 +75,12 @@ def callback(c):
        n_jobs=1,
        n_trials=1,
        sampler_callback=callback,
        metric=metric,
    )

    assert isinstance(tuned.best_config, dict)
    assert len(tuned.trial_data) == 1
    assert metric.direction == direction

    _, c = extract_model_config_section(tuned.best_config)
    assert c["params"]["discriminator_dim"] == c["params"]["generator_dim"]
