diff --git a/src/gretel_client/gretel/artifact_fetching.py b/src/gretel_client/gretel/artifact_fetching.py index 3c8cc8ae..a3d17bf7 100644 --- a/src/gretel_client/gretel/artifact_fetching.py +++ b/src/gretel_client/gretel/artifact_fetching.py @@ -36,26 +36,25 @@ class ReportType(str, Enum): SQS = "sqs" TEXT = "text" + TRANSFORM = "transform" @property def artifact_name(self) -> str: # type: ignore - if self == ReportType.SQS: - return "report" - elif self == ReportType.TEXT: - return "text_metrics_report" + names = { + ReportType.SQS: "report", + ReportType.TRANSFORM: "report", + ReportType.TEXT: "text_metrics_report", + } + return names[self] @dataclass class GretelReport: - """Dataclass for a Gretel synthetic data quality report.""" + """Dataclass base class for Gretel report artifacts.""" as_dict: dict as_html: str - @property - def quality_scores(self) -> dict: - return {d["field"]: d["value"] for d in self.as_dict["summary"]} - def display_in_browser(self): """Display the HTML report in a browser.""" with tempfile.NamedTemporaryFile(suffix=".html") as file: @@ -74,6 +73,18 @@ def save_html(self, save_path: Union[str, Path]): with open(save_path, "w") as file: file.write(self.as_html) + def __repr__(self): + return f"{self.__class__.__name__}{self.as_dict}" + + +@dataclass +class GretelDataQualityReport(GretelReport): + """Dataclass for a Gretel synthetic data quality report.""" + + @property + def quality_scores(self) -> dict: + return {d["field"]: d["value"] for d in self.as_dict["summary"]} + def __repr__(self): r = "\n".join([f" {k}: {v}" for k, v in self.quality_scores.items()]) r = "(\n" + r + "\n)\n" @@ -122,6 +133,9 @@ def fetch_model_report( with model.get_artifact_handle(report_type.artifact_name) as file: report_html = str(file.read(), encoding="utf-8") # type: ignore + if report_type in [ReportType.SQS, ReportType.TEXT]: + return GretelDataQualityReport(as_dict=report_dict, as_html=report_html) + return GretelReport(as_dict=report_dict, as_html=report_html) diff --git a/src/gretel_client/gretel/job_results.py b/src/gretel_client/gretel/job_results.py index 5eca5a93..6f899dda 100644 --- a/src/gretel_client/gretel/job_results.py +++ b/src/gretel_client/gretel/job_results.py @@ -18,6 +18,7 @@ fetch_model_report, fetch_synthetic_data, GretelReport, + ReportType, ) from gretel_client.gretel.config_setup import ( CONFIG_SETUP_DICT, @@ -200,6 +201,8 @@ class TransformResults(GretelJobResults): """URI to the transformed data (as a flat file). This will not be populated until the transforms job succeeds.""" + report: Optional[GretelReport] = None + @property def model_url(self) -> str: """ @@ -214,6 +217,10 @@ def model_config(self) -> str: """ return yaml.safe_dump(self.model.model_config) + @property + def model_config_section(self) -> dict: + return next(iter(self.model.model_config["models"][0].values())) # type: ignore + @property def job_status(self) -> Status: """The current status of the transform job.""" @@ -230,6 +237,11 @@ def refresh(self) -> None: if self.transformed_df is None and PANDAS_IS_INSTALLED: with self.model.get_artifact_handle("data_preview") as fin: self.transformed_df = pd.read_csv(fin) # type: ignore + if ( + self.report is None + and self.model_config_section.get("data_source") is not None + ): + self.report = fetch_model_report(self.model, ReportType.TRANSFORM) # We can fetch model logs no matter what self.transform_logs = fetch_model_logs(self.model) diff --git a/tests/gretel_client/integration/test_transform.py b/tests/gretel_client/integration/test_transform.py index 0bc902d7..a99bc037 100644 --- a/tests/gretel_client/integration/test_transform.py +++ b/tests/gretel_client/integration/test_transform.py @@ -67,6 +67,10 @@ def test_transform( 185, 185, ] + assert transform.report is not None + + # If the report type selection is broken this will fail + str(transform.report) def test_transform_errors(