diff --git a/src/gretel_client/tuner/metrics.py b/src/gretel_client/tuner/metrics.py
index d388611b..61783c72 100644
--- a/src/gretel_client/tuner/metrics.py
+++ b/src/gretel_client/tuner/metrics.py
@@ -2,13 +2,19 @@
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import List
+from pathlib import Path
+from typing import List, Optional, Union
 
+import pandas as pd
+
+from gretel_client.gretel.artifact_fetching import fetch_synthetic_data
 from gretel_client.gretel.config_setup import (
     CONFIG_SETUP_DICT,
     extract_model_config_section,
     ModelName,
 )
+from gretel_client.gretel.exceptions import GretelJobSubmissionError
+from gretel_client.helpers import poll
 from gretel_client.projects.models import Model
 
 
@@ -66,6 +72,7 @@ class BaseTunerMetric(ABC):
     direction: MetricDirection = MetricDirection.MAXIMIZE
 
     def get_gretel_report(self, model: Model) -> dict:
+        """Get the Gretel synthetic data quality report."""
         model_type, _ = extract_model_config_section(model.model_config)
         report_type = CONFIG_SETUP_DICT[model_type].report_type
         if report_type is None:
@@ -77,6 +84,60 @@ def get_gretel_report(self, model: Model) -> dict:
             report = json.load(file)
         return report
 
+    def submit_generate_for_trial(
+        self,
+        model: Model,
+        num_records: Optional[int] = None,
+        seed_data: Optional[Union[str, Path, pd.DataFrame]] = None,
+        **generate_kwargs,
+    ) -> pd.DataFrame:
+        """Submit a generate job for a hyperparameter tuning trial.
+
+        Only one of `num_records` or `seed_data` can be provided. The former
+        will generate a complete synthetic dataset, while the latter will
+        conditionally generate synthetic data based on the seed data.
+
+        Args:
+            model: Gretel `Model` instance.
+            num_records: Number of records to generate.
+            seed_data: Seed data source as a file path or pandas DataFrame.
+
+        Raises:
+            TypeError: If `model` is not a Gretel `Model` instance.
+            GretelJobSubmissionError: If the combination of arguments is invalid.
+
+        Returns:
+            Pandas DataFrame containing the synthetic data.
+        """
+        if not isinstance(model, Model):
+            raise TypeError(f"Expected a Gretel Model object, got {type(model)}.")
+
+        if num_records is not None and seed_data is not None:
+            raise GretelJobSubmissionError(
+                "Only one of `num_records` or `seed_data` can be provided."
+            )
+
+        if num_records is None and seed_data is None:
+            raise GretelJobSubmissionError(
+                "Either `num_records` or `seed_data` must be provided."
+            )
+
+        if num_records is not None:
+            generate_kwargs.update({"num_records": num_records})
+
+        data_source = str(seed_data) if isinstance(seed_data, Path) else seed_data
+
+        record_handler = model.create_record_handler_obj(
+            data_source=data_source,
+            params=generate_kwargs,
+        )
+
+        record_handler.submit()
+
+        poll(record_handler, verbose=False)
+
+        return fetch_synthetic_data(record_handler)
+
     @abstractmethod
     def __call__(self, model: Model) -> float:
         """Calculate the optimization metric and return the score as a float."""
diff --git a/tests/gretel_client/integration/test_tuner.py b/tests/gretel_client/integration/test_tuner.py
index 16c8d14d..556d8e61 100644
--- a/tests/gretel_client/integration/test_tuner.py
+++ b/tests/gretel_client/integration/test_tuner.py
@@ -4,10 +4,16 @@
 from pathlib import Path
 from typing import Callable
 
+import pandas as pd
 import pytest
 
 from gretel_client.gretel.config_setup import extract_model_config_section
 from gretel_client.gretel.interface import Gretel
+from gretel_client.tuner.metrics import (
+    BaseTunerMetric,
+    GretelQualityScore,
+    MetricDirection,
+)
 
 
 @pytest.fixture
@@ -31,7 +37,34 @@ def gretel() -> Gretel:
     gretel.get_project().delete()
 
 
-def test_tuner_tabular(gretel: Gretel, tabular_data_source: Path, tuner_config: Path):
+class CustomMetric(BaseTunerMetric):
+    direction: MetricDirection = MetricDirection.MINIMIZE
+
+    def __call__(self, model):
+        seed_data = pd.DataFrame({"yea-mon": ["01/01/1998"] * 10})
+        df = pd.concat(
+            [
+                self.submit_generate_for_trial(model, seed_data=seed_data),
+                self.submit_generate_for_trial(model, num_records=100),
+            ]
+        )
+        return df["balance"].sum()
+
+
+@pytest.mark.parametrize(
+    "metric,direction",
+    [
+        (GretelQualityScore(), MetricDirection.MAXIMIZE),
+        (CustomMetric(), MetricDirection.MINIMIZE),
+    ],
+)
+def test_tuner_tabular(
+    gretel: Gretel,
+    tabular_data_source: Path,
+    tuner_config: Path,
+    metric: BaseTunerMetric,
+    direction: str,
+):
     def callback(c):
         c["params"]["discriminator_dim"] = c["params"]["generator_dim"]
         return c
@@ -42,10 +75,12 @@ def callback(c):
         n_jobs=1,
         n_trials=1,
         sampler_callback=callback,
+        metric=metric,
     )
 
     assert isinstance(tuned.best_config, dict)
     assert len(tuned.trial_data) == 1
+    assert metric.direction == direction
 
     _, c = extract_model_config_section(tuned.best_config)
     assert c["params"]["discriminator_dim"] == c["params"]["generator_dim"]
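
Usage sketch (illustrative, not part of the diff): a custom tuner metric might call the new submit_generate_for_trial helper as below, assuming a trained model is passed to the metric during a trial; the MeanBalanceMetric name is hypothetical and the "balance" column comes from the integration-test dataset above.

    import pandas as pd

    from gretel_client.tuner.metrics import BaseTunerMetric, MetricDirection


    class MeanBalanceMetric(BaseTunerMetric):
        # Lower mean balance is treated as better for this illustrative metric.
        direction: MetricDirection = MetricDirection.MINIMIZE

        def __call__(self, model) -> float:
            # Exactly one of num_records or seed_data may be passed per call.
            df = self.submit_generate_for_trial(model, num_records=50)
            return float(df["balance"].mean())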