[TASK] PLAT-2662: Add support of PII Replay for QualityReport at gretel client

GitOrigin-RevId: 0a6d65780039eafc70937e9f18076ed7d93a410d
Anastasia Nesterenko committed Nov 14, 2024
1 parent ab5cb2c commit c7e2a07
Showing 5 changed files with 43 additions and 22 deletions.
16 changes: 8 additions & 8 deletions src/gretel_client/evaluation/downstream_classification_report.py
@@ -3,14 +3,13 @@
 from pathlib import Path
 from typing import List, Optional, Union
 
-from gretel_client.config import ClientConfig, get_session_config, RunnerMode
+from gretel_client.config import ClientConfig, RunnerMode
 from gretel_client.evaluation.reports import (
     BaseReport,
     DEFAULT_RECORD_COUNT,
     ReportDictType,
 )
 from gretel_client.projects.common import DataSourceTypes, RefDataTypes
-from gretel_client.projects.models import Model
 from gretel_client.projects.projects import Project
 
 
@@ -19,17 +18,18 @@ class DownstreamClassificationReport(BaseReport):
     Args:
         project: Optional project associated with the report. If no project is passed, a temp project (:obj:`gretel_client.projects.projects.tmp_project`) will be used.
         name: Optional name of the model. If no name is provided, a default name will be used.
         data_source: Data source used for the report (generally your synthetic data).
         ref_data: Reference data used for the report (generally your real data, i.e. the training data for your gretel model).
         test_data: Optional data set used as a test set for training models used in report.
         output_dir: Optional directory path to write the report to. If the directory does not exist, the path will be created for you.
         runner_mode: Determines where to run the model. See :obj:`gretel_client.config.RunnerMode` for a list of valid modes. Manual mode is not explicitly supported.
-    target: The field which the downstream classifiers are trained to predict. Must be present in both data_source and ref_data.
-    holdout: The ratio of data to hold out from ref_data (i.e., your real data) as a test set. Must be between 0.0 and 1.0.
-    models: The list of classifier models to train. If absent or an empty list, use all supported models.
-    metric: The metric used to sort classifier results. "Accuracy" by default.
-    record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
-        as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+        target: The field which the downstream classifiers are trained to predict. Must be present in both data_source and ref_data.
+        holdout: The ratio of data to hold out from ref_data (i.e., your real data) as a test set. Must be between 0.0 and 1.0.
+        models: The list of classifier models to train. If absent or an empty list, use all supported models.
+        metric: The metric used to sort classifier results. "Accuracy" by default.
+        record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
+            as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+            i.e. we will use min(len(train), len(synth)), e.g.
         session: The client session to use, or ``None`` to use the session associated with the project
             (if any), or the default session otherwise.
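For orientation, here is a minimal usage sketch of the class this file defines. The file paths and target column are illustrative assumptions, not part of the commit; the keyword arguments come from the docstring above, and the sketch assumes a configured Gretel session.

    from gretel_client.evaluation.downstream_classification_report import (
        DownstreamClassificationReport,
    )

    # Illustrative inputs; requires a configured Gretel session.
    report = DownstreamClassificationReport(
        data_source="synthetic.csv",  # generally your synthetic data
        ref_data="real.csv",          # generally your real (training) data
        target="label",               # must be present in both data sets
    )
    report.run()
    print(report.as_dict)  # report contents as a dictionary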
16 changes: 8 additions & 8 deletions src/gretel_client/evaluation/downstream_regression_report.py
@@ -3,14 +3,13 @@
 from pathlib import Path
 from typing import List, Optional, Union
 
-from gretel_client.config import ClientConfig, get_session_config, RunnerMode
+from gretel_client.config import ClientConfig, RunnerMode
 from gretel_client.evaluation.reports import (
     BaseReport,
     DEFAULT_RECORD_COUNT,
     ReportDictType,
 )
 from gretel_client.projects.common import DataSourceTypes, RefDataTypes
-from gretel_client.projects.models import Model
 from gretel_client.projects.projects import Project
 
 
@@ -19,17 +18,18 @@ class DownstreamRegressionReport(BaseReport):
     Args:
         project: Optional project associated with the report. If no project is passed, a temp project (:obj:`gretel_client.projects.projects.tmp_project`) will be used.
         name: Optional name of the model. If no name is provided, a default name will be used.
         data_source: Data source used for the report.
         ref_data: Reference data used for the report.
         test_data: Optional data set used as a test set for training models used in report.
         output_dir: Optional directory path to write the report to. If the directory does not exist, the path will be created for you.
         runner_mode: Determines where to run the model. See :obj:`gretel_client.config.RunnerMode` for a list of valid modes. Manual mode is not explicitly supported.
-    target: The field which the downstream regression models are trained to predict. Must be present in both data_source and ref_data.
-    holdout: The ratio of data to hold out from ref_data (i.e., your real data) as a test set. Must be between 0.0 and 1.0.
-    models: The list of regression models to train. If absent or an empty list, use all supported models.
-    metric: The metric used to sort regression results. "R2" by default.
-    record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
-        as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+        target: The field which the downstream regression models are trained to predict. Must be present in both data_source and ref_data.
+        holdout: The ratio of data to hold out from ref_data (i.e., your real data) as a test set. Must be between 0.0 and 1.0.
+        models: The list of regression models to train. If absent or an empty list, use all supported models.
+        metric: The metric used to sort regression results. "R2" by default.
+        record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
+            as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+            i.e. we will use min(len(train), len(synth)), e.g.
         session: The client session to use, or ``None`` to use the session associated with the project
             (if any), or the default session otherwise.
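The regression variant exposes the same knobs with regression-specific defaults. A short sketch, again with illustrative inputs and assuming the same `run()` interface:

    from gretel_client.evaluation.downstream_regression_report import (
        DownstreamRegressionReport,
    )

    # Illustrative inputs; requires a configured Gretel session.
    report = DownstreamRegressionReport(
        data_source="synthetic.csv",
        ref_data="real.csv",
        target="price",  # numeric column present in both data sets
        holdout=0.2,     # hold out 20% of ref_data as a test set
        metric="R2",     # default sort metric per the docstring
    )
    report.run()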
23 changes: 18 additions & 5 deletions src/gretel_client/evaluation/quality_report.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import List, Optional, Union
 
-from gretel_client.config import ClientConfig, get_session_config, RunnerMode
+from gretel_client.config import ClientConfig, RunnerMode
 from gretel_client.evaluation.reports import (
     BaseReport,
     DEFAULT_CORRELATION_COLUMNS,
@@ -20,24 +20,28 @@ class QualityReport(BaseReport):
     Args:
         project: Optional project associated with the report. If no project is passed, a temp project (:obj:`gretel_client.projects.projects.tmp_project`) will be used.
         name: Optional name of the model. If no name is provided, a default name will be used.
         data_source: Data source used for the report.
         ref_data: Reference data used for the report.
         output_dir: Optional directory path to write the report to. If the directory does not exist, the path will be created for you.
         runner_mode: Determines where to run the model. See :obj:`gretel_client.config.RunnerMode` for a list of valid modes. Manual mode is not explicitly supported.
-    record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
-        as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+        record_count: Number of rows to use from the data sets, 5000 by default. A value of 0 means "use as many rows/columns
+            as possible." We still attempt to maintain parity between the data sets for "fair" comparisons,
+            i.e. we will use min(len(train), len(synth)), e.g.
         correlation_column_count: Similar to record_count, but for number of columns used for correlation calculations.
         column_count: Similar to record_count, but for number of columns used for all other calculations.
-    mandatory_columns: Use in conjuction with correlation_column_count and column_count. The columns listed will be included
-        in the sample of columns. Any additional requested columns will be selected randomly.
+        mandatory_columns: Use in conjunction with correlation_column_count and column_count. The columns listed will be included
+            in the sample of columns. Any additional requested columns will be selected randomly.
         session: The client session to use, or ``None`` to use the session associated with the project
             (if any), or the default session otherwise.
         test_data: Optional reference data used for the Privacy Metrics of the report.
+        run_pii_replay: Determines if PII Replay should be run for the report. If True, the PII replay section will be included in the report.
+        pii_entities: List of PII entities to include in the PII Replay section. If None, default entities will be used. This is used only if run_pii_replay is True.
     """
 
     _model_dict: dict = {
         "schema_version": "1.0",
+        "name": "evaluate-quality-model",
         "models": [
             {
                 "evaluate": {
@@ -71,6 +75,8 @@ def __init__(
         mandatory_columns: Optional[List[str]] = [],
         session: Optional[ClientConfig] = None,
         test_data: Optional[RefDataTypes] = None,
+        run_pii_replay: bool = False,
+        pii_entities: Optional[List[str]] = None,
     ):
         project, session = BaseReport.resolve_session(project, session)
         runner_mode = runner_mode or session.default_runner
@@ -91,6 +97,13 @@
         params["sqs_report_columns"] = column_count
         params["mandatory_columns"] = mandatory_columns
 
+        if run_pii_replay:
+            self._model_dict["models"][0]["evaluate"]["pii_replay"] = {"skip": False}
+            if pii_entities:
+                self._model_dict["models"][0]["evaluate"]["pii_replay"][
+                    "entities"
+                ] = pii_entities
+
         super().__init__(
             project,
             data_source,
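Taken together, these changes let callers opt into PII Replay when building a quality report. A minimal sketch of the new surface follows; the file names and entity list are illustrative assumptions, while the keyword arguments match the diff above.

    from gretel_client.evaluation.quality_report import QualityReport

    # Illustrative inputs; requires a configured Gretel session.
    report = QualityReport(
        data_source="synthetic.csv",
        ref_data="real.csv",
        run_pii_replay=True,             # writes {"skip": False} under evaluate.pii_replay
        pii_entities=["name", "email"],  # optional; defaults apply when omitted
    )
    report.run()

Per `__init__`, enabling the flag mutates `_model_dict` so the submitted model config carries an `evaluate.pii_replay` section, exactly the shape the integration test below asserts.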
3 changes: 2 additions & 1 deletion src/gretel_client/gretel/interface.py
@@ -378,6 +378,7 @@ def submit_train(
         if (
             model_setup.report_type is not None
             and model_config_section.get("data_source") is not None
             and model.status == Status.COMPLETED
+            and not model_config_section.get("evaluate", {}).get("skip", False)
         ):
             report = fetch_model_report(model, model_setup.report_type)
@@ -623,7 +624,7 @@ def submit_transform(
 
         Returns:
             Transform results dataclass with the transformed_dataframe, transform_logs
-        and as attributes.
+            and as attributes.
 
         Example::
7 changes: 7 additions & 0 deletions tests/gretel_client/integration/test_quality_report.py
@@ -75,6 +75,7 @@ def test_report_initialization_with_defaults(
     assert report.output_dir
     assert report.model_config == {
         "schema_version": "1.0",
+        "name": "evaluate-quality-model",
         "models": [
             {
                 "evaluate": {
@@ -105,6 +106,8 @@ def test_report_initialization_with_custom_params(
         output_dir=tmpdir,
         runner_mode=RunnerMode.CLOUD,
         mandatory_columns=["foo", "bar", "baz"],
+        run_pii_replay=True,
+        pii_entities=["name", "email"],
     )
     assert report.project
     assert report.data_source
@@ -124,6 +127,10 @@
                     "sqs_report_columns": 250,
                     "sqs_report_rows": 5000,
                 },
+                "pii_replay": {
+                    "skip": False,
+                    "entities": ["name", "email"],
+                },
             }
         }
     ],
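The same shape can be checked locally before anything is submitted, since the constructor assembles the config up front and the test asserts against `report.model_config`. A sketch, assuming a configured session and illustrative file paths:

    report = QualityReport(
        data_source="synthetic.csv",
        ref_data="real.csv",
        run_pii_replay=True,
        pii_entities=["name", "email"],
    )
    pii = report.model_config["models"][0]["evaluate"]["pii_replay"]
    assert pii == {"skip": False, "entities": ["name", "email"]}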
