Prediction to bigquery component - initial code #210

Merged · 5 commits · Feb 6, 2023
3 changes: 3 additions & 0 deletions CODEOWNERS
@@ -55,5 +55,8 @@
/tfx_addons/message_exit_handler @hanneshapke
/tfx_addons/utils @hanneshapke

# Predictions to BigQuery Component
/tfx_addons/predictions_to_bigquery @hanneshapke

# PandasTransform Component
/tfx_addons/pandas_transform @rcrowe-google
Empty file.
100 changes: 100 additions & 0 deletions tfx_addons/predictions_to_bigquery/component.py
@@ -0,0 +1,100 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's add the Copyright headers here please.

Digits Prediction-to-BigQuery: Functionality to write prediction results usually from a BulkInferrer to BigQuery.
"""

from typing import Optional

from tfx import types
from tfx.dsl.components.base import base_component, executor_spec
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter, ExecutionParameter

from .executor import Executor as AnnotateUnlabeledCategoryDataExecutor

_MIN_THRESHOLD = 0.8
_VOCAB_FILE = "vocab_label_txt"


class AnnotateUnlabeledCategoryDataComponentSpec(types.ComponentSpec):

  PARAMETERS = {
      # These are parameters that will be passed in the call to
      # create an instance of this component.
      "vocab_label_file": ExecutionParameter(type=str),
      "bq_table_name": ExecutionParameter(type=str),
      "filter_threshold": ExecutionParameter(type=float),
      "table_suffix": ExecutionParameter(type=str),
      "table_partitioning": ExecutionParameter(type=bool),
      "expiration_time_delta": ExecutionParameter(type=int),
  }
  INPUTS = {
      # This will be a dictionary with input artifacts, including URIs.
      "transform_graph": ChannelParameter(type=standard_artifacts.TransformGraph),
      "inference_results": ChannelParameter(type=standard_artifacts.InferenceResult),
      "schema": ChannelParameter(type=standard_artifacts.Schema),
  }
  OUTPUTS = {
      "bigquery_export": ChannelParameter(type=standard_artifacts.String),
  }


class AnnotateUnlabeledCategoryDataComponent(base_component.BaseComponent):
"""
AnnotateUnlabeledCategoryData Component.

The component takes the following input artifacts:
* Inference results: InferenceResult
* Transform graph: TransformGraph
* Schema: Schema (optional) if not present, the component will determine the schema
(only predtion supported at the moment)

The component takes the following parameters:
* vocab_label_file: str - The file name of the file containing the vocabulary labels
(produced by TFT).
* bq_table_name: str - The name of the BigQuery table to write the results to.
* filter_threshold: float - The minimum probability threshold for a prediction to
be considered a positive, thrustworthy prediction. Default is 0.8.
* table_suffix: str (optional) - If provided, the generated datetime string will
be added the BigQuery table name as suffix. The default is %Y%m%d.
* table_partitioning: bool - Whether to partition the table by DAY. If True,
the generated BigQuery table will be partition by date. If False, no partitioning will
be applied. Default is True.
* expiration_time_delta: int (optional) - The number of seconds after which the table will expire.

The component produces the following output artifacts:
* bigquery_export: String - The URI of the BigQuery table containing the results.
"""

  SPEC_CLASS = AnnotateUnlabeledCategoryDataComponentSpec
  EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(AnnotateUnlabeledCategoryDataExecutor)

  def __init__(
      self,
      inference_results: Optional[types.Channel] = None,
      transform_graph: Optional[types.Channel] = None,
      bq_table_name: Optional[str] = None,
      vocab_label_file: str = _VOCAB_FILE,
      filter_threshold: float = _MIN_THRESHOLD,
      table_suffix: str = "%Y%m%d",
      table_partitioning: bool = True,
      schema: Optional[types.Channel] = None,
      expiration_time_delta: Optional[int] = 0,
      bigquery_export: Optional[types.Channel] = None,
  ):
    bigquery_export = bigquery_export or types.Channel(type=standard_artifacts.String)
    schema = schema or types.Channel(type=standard_artifacts.Schema)

    spec = AnnotateUnlabeledCategoryDataComponentSpec(
        inference_results=inference_results,
        transform_graph=transform_graph,
        schema=schema,
        bq_table_name=bq_table_name,
        vocab_label_file=vocab_label_file,
        filter_threshold=filter_threshold,
        table_suffix=table_suffix,
        table_partitioning=table_partitioning,
        expiration_time_delta=expiration_time_delta,
        bigquery_export=bigquery_export,
    )
    super().__init__(spec=spec)
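For orientation, here is a minimal usage sketch (not part of this diff) showing how the component could be instantiated. The stand-in channels and the table name are illustrative assumptions; in a real pipeline the channels would come from upstream components such as a BulkInferrer, a Transform, and a SchemaGen.

from tfx.types import channel_utils, standard_artifacts
from tfx_addons.predictions_to_bigquery.component import AnnotateUnlabeledCategoryDataComponent

# Stand-in channels for illustration; in a pipeline these would be, e.g.,
# bulk_inferrer.outputs["inference_result"], transform.outputs["transform_graph"],
# and schema_gen.outputs["schema"].
inference_results = channel_utils.as_channel([standard_artifacts.InferenceResult()])
transform_graph = channel_utils.as_channel([standard_artifacts.TransformGraph()])
schema = channel_utils.as_channel([standard_artifacts.Schema()])

predictions_to_bq = AnnotateUnlabeledCategoryDataComponent(
    inference_results=inference_results,
    transform_graph=transform_graph,
    schema=schema,
    bq_table_name="my_project:my_dataset.predictions",  # hypothetical table
    vocab_label_file="vocab_label_txt",
    filter_threshold=0.8,
    table_suffix="%Y%m%d",
    table_partitioning=True,
    expiration_time_delta=60 * 60 * 24 * 7,  # expire after seven days
)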
178 changes: 178 additions & 0 deletions tfx_addons/predictions_to_bigquery/executor.py
@@ -0,0 +1,178 @@
"""
Executor functionality to write prediction results usually from a BulkInferrer to BigQuery.
"""

import datetime
import os
from typing import Any, Dict, List, Tuple

import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft
from absl import logging
from tensorflow.python.eager.context import eager_mode
from tensorflow_serving.apis import prediction_log_pb2
from tfx import types
from tfx.dsl.components.base import base_beam_executor
from tfx.types import artifact_utils

from .utils import (convert_single_value_to_native_py_value,
                    create_annotation_fields, feature_to_bq_schema,
                    load_schema, parse_schema)

_SCORE_MULTIPLIER = 1e6
_SCHEMA_FILE = "schema.pbtxt"
_ADDITIONAL_BQ_PARAMETERS = {}


@beam.typehints.with_input_types(prediction_log_pb2.PredictionLog)
@beam.typehints.with_output_types(beam.typehints.Dict[str, Any])
class FilterPredictionToDictFn(beam.DoFn):
"""
Convert a prediction to a dictionary.
"""

def __init__(
self,
labels: List,
features: Any,
ts: datetime.datetime,
filter_threshold: float,
score_multiplier: int = _SCORE_MULTIPLIER,
):
self.labels = labels
self.features = features
self.filter_threshold = filter_threshold
self.score_multiplier = score_multiplier
self.ts = ts

def _fix_types(self, example):
with eager_mode():
return [convert_single_value_to_native_py_value(v) for v in example.values()]

def _parse_prediction(self, predictions):
prediction_id = np.argmax(predictions)
logging.debug("Prediction id: %s", prediction_id)
logging.debug("Predictions: %s", predictions)
label = self.labels[prediction_id]
score = predictions[0][prediction_id]
return label, score

def process(self, element):
parsed_examples = tf.make_ndarray(element.predict_log.request.inputs["examples"])
parsed_predictions = tf.make_ndarray(element.predict_log.response.outputs["output_0"])

example_values = self._fix_types(tf.io.parse_single_example(parsed_examples[0], self.features))
label, score = self._parse_prediction(parsed_predictions)

if score > self.filter_threshold:
yield {
# TODO: features should be read dynamically
"feature0": example_values[0],
"feature1": example_values[1],
"feature2": example_values[2],
"category_label": label,
"score": int(score * self.score_multiplier),
"datetime": self.ts,
}


class Executor(base_beam_executor.BaseBeamExecutor):
"""
Beam Executor for predictions_to_bq.
"""

def Do(
self,
input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any],
) -> None:
"""Do function for predictions_to_bq executor."""

ts = datetime.datetime.now().replace(second=0, microsecond=0)

# check required executive properties
if exec_properties["bq_table_name"] is None:
raise ValueError("bq_table_name must be set in exec_properties")
if exec_properties["filter_threshold"] is None:
raise ValueError("filter_threshold must be set in exec_properties")
if exec_properties["vocab_label_file"] is None:
raise ValueError("vocab_label_file must be set in exec_properties")

# get labels from tf transform generated vocab file
transform_output = artifact_utils.get_single_uri(input_dict["transform_graph"])
tf_transform_output = tft.TFTransformOutput(transform_output)
tft_vocab = tf_transform_output.vocabulary_by_name(vocab_filename=exec_properties["vocab_label_file"])
labels = [label.decode() for label in tft_vocab]
logging.info(f"found the following labels from TFT vocab: {labels}")

# get predictions from predict log
inference_results_uri = artifact_utils.get_single_uri(input_dict["inference_results"])

# set table prefix and partitioning parameters
bq_table_name = exec_properties["bq_table_name"]
if exec_properties["table_suffix"]:
bq_table_name += "_" + ts.strftime(exec_properties["table_suffix"])

if exec_properties["expiration_time_delta"]:
expiration_time = int(ts.timestamp()) + exec_properties["expiration_time_delta"]
_ADDITIONAL_BQ_PARAMETERS.update({"expirationTime": str(expiration_time)})
logging.info(f"expiration time on {bq_table_name} set to {expiration_time}")

if exec_properties["table_partitioning"]:
_ADDITIONAL_BQ_PARAMETERS.update({"timePartitioning": {"type": "DAY"}})
logging.info(f"time partitioning on {bq_table_name} set to DAY")

# set prediction result file path and decoder
prediction_log_path = f"{inference_results_uri}/*.gz"
prediction_log_decoder = beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)

# get features from tfx schema if present
if input_dict["schema"]:
schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE)
features = load_schema(schema_uri)

# generate features from predictions
else:
features = parse_schema(prediction_log_path)

# generate bigquery schema from tfx schema (features)
bq_schema_fields = feature_to_bq_schema(features, required=True)
bq_schema_fields.extend(
create_annotation_fields(
label_field_name="category_label", score_field_name="score", required=True, add_datetime_field=True
)
)
bq_schema = {"fields": bq_schema_fields}
logging.info(f"generated bq_schema: {bq_schema}")

with self._make_beam_pipeline() as pipeline:
_ = (
pipeline
| "Read Prediction Log" >> beam.io.ReadFromTFRecord(prediction_log_path, coder=prediction_log_decoder)
| "Filter and Convert to Dict"
>> beam.ParDo(
FilterPredictionToDictFn(
labels=labels,
features=features,
ts=ts,
filter_threshold=exec_properties["filter_threshold"],
)
)
| "Write Dict to BQ"
>> beam.io.gcp.bigquery.WriteToBigQuery(
table=bq_table_name,
schema=bq_schema,
additional_bq_parameters=_ADDITIONAL_BQ_PARAMETERS,
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
)
)

bigquery_export = artifact_utils.get_single_instance(output_dict["bigquery_export"])

bigquery_export.set_string_custom_property("generated_bq_table_name", bq_table_name)

logging.info(f"Annotated data exported to {bq_table_name}")
34 changes: 34 additions & 0 deletions tfx_addons/predictions_to_bigquery/test_component.py
@@ -0,0 +1,34 @@
"""
Tests around Digits Prediction-to-BigQuery component.
"""

import tensorflow as tf
from tfx.types import channel_utils, standard_artifacts

from . import component


class ComponentTest(tf.test.TestCase):
  def setUp(self):
    super().setUp()
    self._transform_graph = channel_utils.as_channel([standard_artifacts.TransformGraph()])
    self._inference_results = channel_utils.as_channel([standard_artifacts.InferenceResult()])
    self._schema = channel_utils.as_channel([standard_artifacts.Schema()])

  def testConstruct(self):
    # not a real test, just checking if the component can be
    # instantiated
    _ = component.AnnotateUnlabeledCategoryDataComponent(
        transform_graph=self._transform_graph,
        inference_results=self._inference_results,
        schema=self._schema,
        bq_table_name="gcp_project:bq_database.table",
        vocab_label_file="vocab_txt",
        filter_threshold=0.1,
        table_suffix="%Y",
        table_partitioning=False,
    )


if __name__ == "__main__":
tf.test.main()