[FIX] Add Tv2 report to SDK API
GitOrigin-RevId: 0c54b41c4763876b7e649da93540ce6f9f123995
parkanzky committed Nov 19, 2024
1 parent 2a26175 commit ec3221a
Showing 3 changed files with 39 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/gretel_client/gretel/artifact_fetching.py
@@ -36,26 +36,25 @@ class ReportType(str, Enum):

     SQS = "sqs"
     TEXT = "text"
+    TRANSFORM = "transform"

     @property
     def artifact_name(self) -> str:  # type: ignore
-        if self == ReportType.SQS:
-            return "report"
-        elif self == ReportType.TEXT:
-            return "text_metrics_report"
+        names = {
+            ReportType.SQS: "report",
+            ReportType.TRANSFORM: "report",
+            ReportType.TEXT: "text_metrics_report",
+        }
+        return names[self]


 @dataclass
 class GretelReport:
-    """Dataclass for a Gretel synthetic data quality report."""
+    """Dataclass base class for Gretel report artifacts."""

     as_dict: dict
     as_html: str

-    @property
-    def quality_scores(self) -> dict:
-        return {d["field"]: d["value"] for d in self.as_dict["summary"]}
-
     def display_in_browser(self):
         """Display the HTML report in a browser."""
         with tempfile.NamedTemporaryFile(suffix=".html") as file:
@@ -74,6 +73,18 @@ def save_html(self, save_path: Union[str, Path]):
         with open(save_path, "w") as file:
             file.write(self.as_html)

+    def __repr__(self):
+        return f"{self.__class__.__name__}{self.as_dict}"
+
+
+@dataclass
+class GretelDataQualityReport(GretelReport):
+    """Dataclass for a Gretel synthetic data quality report."""
+
+    @property
+    def quality_scores(self) -> dict:
+        return {d["field"]: d["value"] for d in self.as_dict["summary"]}
+
     def __repr__(self):
         r = "\n".join([f" {k}: {v}" for k, v in self.quality_scores.items()])
         r = "(\n" + r + "\n)\n"
@@ -122,6 +133,9 @@ def fetch_model_report(
     with model.get_artifact_handle(report_type.artifact_name) as file:
         report_html = str(file.read(), encoding="utf-8")  # type: ignore

+    if report_type in [ReportType.SQS, ReportType.TEXT]:
+        return GretelDataQualityReport(as_dict=report_dict, as_html=report_html)
+
     return GretelReport(as_dict=report_dict, as_html=report_html)


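Taken together, the changes in artifact_fetching.py mean that quality reports (SQS and text metrics) now come back as GretelDataQualityReport, which keeps the quality_scores property and the score-listing repr, while the new transform report falls through to the plain GretelReport base class. A minimal sketch of that dispatch follows; it assumes the classes above are importable from gretel_client.gretel.artifact_fetching, and the dict payloads are made-up stand-ins rather than real report artifacts.

# Minimal dispatch sketch; the payloads below are hypothetical stand-ins.
from gretel_client.gretel.artifact_fetching import (
    GretelDataQualityReport,
    GretelReport,
    ReportType,
)


def pick_report_class(report_type: ReportType) -> type:
    # Mirrors the branch added to fetch_model_report: quality reports get the
    # subclass with quality_scores; transform reports use the base class.
    if report_type in [ReportType.SQS, ReportType.TEXT]:
        return GretelDataQualityReport
    return GretelReport


sqs_report = pick_report_class(ReportType.SQS)(
    as_dict={"summary": [{"field": "synthetic_data_quality_score", "value": 95}]},
    as_html="<html>...</html>",
)
print(sqs_report.quality_scores)  # {'synthetic_data_quality_score': 95}

transform_report = pick_report_class(ReportType.TRANSFORM)(
    as_dict={"fields_transformed": 7},
    as_html="<html>...</html>",
)
print(transform_report)  # GretelReport{'fields_transformed': 7}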
12 changes: 12 additions & 0 deletions src/gretel_client/gretel/job_results.py
@@ -18,6 +18,7 @@
     fetch_model_report,
     fetch_synthetic_data,
     GretelReport,
+    ReportType,
 )
 from gretel_client.gretel.config_setup import (
     CONFIG_SETUP_DICT,
@@ -200,6 +201,8 @@ class TransformResults(GretelJobResults):
     """URI to the transformed data (as a flat file). This will
     not be populated until the transforms job succeeds."""

+    report: Optional[GretelReport] = None
+
     @property
     def model_url(self) -> str:
         """
@@ -214,6 +217,10 @@ def model_config(self) -> str:
         """
         return yaml.safe_dump(self.model.model_config)

+    @property
+    def model_config_section(self) -> dict:
+        return next(iter(self.model.model_config["models"][0].values()))  # type: ignore
+
     @property
     def job_status(self) -> Status:
         """The current status of the transform job."""
@@ -230,6 +237,11 @@ def refresh(self) -> None:
             if self.transformed_df is None and PANDAS_IS_INSTALLED:
                 with self.model.get_artifact_handle("data_preview") as fin:
                     self.transformed_df = pd.read_csv(fin)  # type: ignore
+            if (
+                self.report is None
+                and self.model_config_section.get("data_source") is not None
+            ):
+                self.report = fetch_model_report(self.model, ReportType.TRANSFORM)

         # We can fetch model logs no matter what
         self.transform_logs = fetch_model_logs(self.model)
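With these changes, refresh() on a transform job also fetches the new report (when the model config has a data_source) and stores it on TransformResults.report. A hedged usage sketch follows; the Gretel session object and submit_transform() call are assumptions about the surrounding high-level SDK, not something this diff shows, while report, refresh(), and save_html() are the members touched above.

# Hypothetical end-to-end usage; the Gretel() session and submit_transform()
# names are assumptions, only the report/refresh behavior comes from this diff.
from gretel_client import Gretel

gretel = Gretel(api_key="prompt")
transform = gretel.submit_transform(
    config="transform_config.yaml",
    data_source="my_dataset.csv",
)

transform.refresh()  # pulls transformed_df, the transform report, and logs
if transform.report is not None:
    print(transform.report)  # base GretelReport repr: class name + dict
    transform.report.save_html("transform_report.html")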
4 changes: 4 additions & 0 deletions tests/gretel_client/integration/test_transform.py
@@ -67,6 +67,10 @@ def test_transform(
         185,
         185,
     ]
+    assert transform.report is not None
+
+    # If the report type selection is broken this will fail
+    str(transform.report)


 def test_transform_errors(
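The bare str(transform.report) call in the test is a cheap guard on the class selection: GretelDataQualityReport.__repr__ reads quality_scores, which requires a "summary" section that a transform report payload presumably lacks (hence the class split), so returning the wrong class would make stringification raise. A small sketch of that failure mode, using a made-up payload:

# Why str() catches a wrong report-class choice: the data-quality repr needs
# as_dict["summary"], which a transform report payload is assumed to lack.
# The payload below is a hypothetical stand-in, not a real transform report.
from gretel_client.gretel.artifact_fetching import (
    GretelDataQualityReport,
    GretelReport,
)

transform_payload = {"fields_transformed": 7}

print(GretelReport(as_dict=transform_payload, as_html=""))  # GretelReport{...}

try:
    str(GretelDataQualityReport(as_dict=transform_payload, as_html=""))
except KeyError as err:
    print(f"wrong class selected, missing key: {err}")  # 'summary'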
