Prediction to bigquery component - initial code #210
Merged
Changes from all commits (5 commits)
Empty file.

@@ -0,0 +1,100 @@
""" | ||
Digits Prediction-to-BigQuery: Functionality to write prediction results usually from a BulkInferrer to BigQuery. | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from tfx import types | ||
from tfx.dsl.components.base import base_component, executor_spec | ||
from tfx.types import standard_artifacts | ||
from tfx.types.component_spec import ChannelParameter, ExecutionParameter | ||
|
||
from .executor import Executor as AnnotateUnlabeledCategoryDataExecutor | ||
|
||
_MIN_THRESHOLD = 0.8 | ||
_VOCAB_FILE = "vocab_label_txt" | ||
|
||
|
||
class AnnotateUnlabeledCategoryDataComponentSpec(types.ComponentSpec): | ||
|
||
PARAMETERS = { | ||
# These are parameters that will be passed in the call to | ||
# create an instance of this component. | ||
"vocab_label_file": ExecutionParameter(type=str), | ||
"bq_table_name": ExecutionParameter(type=str), | ||
"filter_threshold": ExecutionParameter(type=float), | ||
"table_suffix": ExecutionParameter(type=str), | ||
"table_partitioning": ExecutionParameter(type=bool), | ||
"expiration_time_delta": ExecutionParameter(type=int), | ||
} | ||
INPUTS = { | ||
# This will be a dictionary with input artifacts, including URIs | ||
"transform_graph": ChannelParameter(type=standard_artifacts.TransformGraph), | ||
"inference_results": ChannelParameter(type=standard_artifacts.InferenceResult), | ||
"schema": ChannelParameter(type=standard_artifacts.Schema), | ||
} | ||
OUTPUTS = { | ||
"bigquery_export": ChannelParameter(type=standard_artifacts.String), | ||
} | ||
|
||
|
||
class AnnotateUnlabeledCategoryDataComponent(base_component.BaseComponent): | ||
""" | ||
AnnotateUnlabeledCategoryData Component. | ||
|
||
The component takes the following input artifacts: | ||
* Inference results: InferenceResult | ||
* Transform graph: TransformGraph | ||
* Schema: Schema (optional) if not present, the component will determine the schema | ||
(only predtion supported at the moment) | ||
|
||
The component takes the following parameters: | ||
* vocab_label_file: str - The file name of the file containing the vocabulary labels | ||
(produced by TFT). | ||
* bq_table_name: str - The name of the BigQuery table to write the results to. | ||
* filter_threshold: float - The minimum probability threshold for a prediction to | ||
be considered a positive, thrustworthy prediction. Default is 0.8. | ||
* table_suffix: str (optional) - If provided, the generated datetime string will | ||
be added the BigQuery table name as suffix. The default is %Y%m%d. | ||
* table_partitioning: bool - Whether to partition the table by DAY. If True, | ||
the generated BigQuery table will be partition by date. If False, no partitioning will | ||
be applied. Default is True. | ||
* expiration_time_delta: int (optional) - The number of seconds after which the table will expire. | ||
|
||
The component produces the following output artifacts: | ||
* bigquery_export: String - The URI of the BigQuery table containing the results. | ||
""" | ||
|
||
SPEC_CLASS = AnnotateUnlabeledCategoryDataComponentSpec | ||
EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(AnnotateUnlabeledCategoryDataExecutor) | ||
|
||
def __init__( | ||
self, | ||
inference_results: types.Channel = None, | ||
transform_graph: types.Channel = None, | ||
bq_table_name: str = None, | ||
vocab_label_file: str = _VOCAB_FILE, | ||
filter_threshold: float = _MIN_THRESHOLD, | ||
table_suffix: str = "%Y%m%d", | ||
table_partitioning: bool = True, | ||
schema: Optional[types.Channel] = None, | ||
expiration_time_delta: Optional[int] = 0, | ||
bigquery_export: Optional[types.Channel] = None, | ||
): | ||
|
||
bigquery_export = bigquery_export or types.Channel(type=standard_artifacts.String) | ||
schema = schema or types.Channel(type=standard_artifacts.Schema()) | ||
|
||
spec = AnnotateUnlabeledCategoryDataComponentSpec( | ||
inference_results=inference_results, | ||
transform_graph=transform_graph, | ||
schema=schema, | ||
bq_table_name=bq_table_name, | ||
vocab_label_file=vocab_label_file, | ||
filter_threshold=filter_threshold, | ||
table_suffix=table_suffix, | ||
table_partitioning=table_partitioning, | ||
expiration_time_delta=expiration_time_delta, | ||
bigquery_export=bigquery_export, | ||
) | ||
super().__init__(spec=spec) |
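
For context, a minimal sketch of how this component might be wired into a TFX pipeline downstream of a BulkInferrer. The upstream component names (transform, schema_gen, bulk_inferrer), the package path, and the parameter values are illustrative assumptions, not part of this PR:

# Hypothetical pipeline wiring; assumes transform, schema_gen and
# bulk_inferrer are TFX components defined earlier in the pipeline.
from predictions_to_bigquery import component  # package path is an assumption

annotate = component.AnnotateUnlabeledCategoryDataComponent(
    inference_results=bulk_inferrer.outputs["inference_result"],
    transform_graph=transform.outputs["transform_graph"],
    schema=schema_gen.outputs["schema"],
    bq_table_name="gcp_project:dataset.annotated_predictions",
    filter_threshold=0.9,
    table_partitioning=True,
    expiration_time_delta=7 * 24 * 3600,  # let the table expire after one week
)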
@@ -0,0 +1,178 @@
""" | ||
Executor functionality to write prediction results usually from a BulkInferrer to BigQuery. | ||
""" | ||
|
||
import datetime | ||
import os | ||
from typing import Any, Dict, List, Tuple | ||
|
||
import apache_beam as beam | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_transform as tft | ||
from absl import logging | ||
from tensorflow.python.eager.context import eager_mode | ||
from tensorflow_serving.apis import prediction_log_pb2 | ||
from tfx import types | ||
from tfx.dsl.components.base import base_beam_executor | ||
from tfx.types import artifact_utils | ||
|
||
from .utils import (convert_single_value_to_native_py_value, | ||
create_annotation_fields, feature_to_bq_schema, | ||
load_schema, parse_schema) | ||
|
||
_SCORE_MULTIPLIER = 1e6 | ||
_SCHEMA_FILE = "schema.pbtxt" | ||
_ADDITIONAL_BQ_PARAMETERS = {} | ||
|
||
|
||
@beam.typehints.with_input_types(str) | ||
@beam.typehints.with_output_types(beam.typehints.Iterable[Tuple[str, str, Any]]) | ||
class FilterPredictionToDictFn(beam.DoFn): | ||
""" | ||
Convert a prediction to a dictionary. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
labels: List, | ||
features: Any, | ||
ts: datetime.datetime, | ||
filter_threshold: float, | ||
score_multiplier: int = _SCORE_MULTIPLIER, | ||
): | ||
self.labels = labels | ||
self.features = features | ||
self.filter_threshold = filter_threshold | ||
self.score_multiplier = score_multiplier | ||
self.ts = ts | ||
|
||
def _fix_types(self, example): | ||
with eager_mode(): | ||
return [convert_single_value_to_native_py_value(v) for v in example.values()] | ||
|
||
def _parse_prediction(self, predictions): | ||
prediction_id = np.argmax(predictions) | ||
logging.debug("Prediction id: %s", prediction_id) | ||
logging.debug("Predictions: %s", predictions) | ||
label = self.labels[prediction_id] | ||
score = predictions[0][prediction_id] | ||
return label, score | ||
|
||
def process(self, element): | ||
parsed_examples = tf.make_ndarray(element.predict_log.request.inputs["examples"]) | ||
parsed_predictions = tf.make_ndarray(element.predict_log.response.outputs["output_0"]) | ||
|
||
example_values = self._fix_types(tf.io.parse_single_example(parsed_examples[0], self.features)) | ||
label, score = self._parse_prediction(parsed_predictions) | ||
|
||
if score > self.filter_threshold: | ||
yield { | ||
# TODO: features should be read dynamically | ||
"feature0": example_values[0], | ||
"feature1": example_values[1], | ||
"feature2": example_values[2], | ||
"category_label": label, | ||
"score": int(score * self.score_multiplier), | ||
"datetime": self.ts, | ||
} | ||
|
||
|
||
class Executor(base_beam_executor.BaseBeamExecutor): | ||
""" | ||
Beam Executor for predictions_to_bq. | ||
""" | ||
|
||
def Do( | ||
self, | ||
input_dict: Dict[str, List[types.Artifact]], | ||
output_dict: Dict[str, List[types.Artifact]], | ||
exec_properties: Dict[str, Any], | ||
) -> None: | ||
"""Do function for predictions_to_bq executor.""" | ||
|
||
ts = datetime.datetime.now().replace(second=0, microsecond=0) | ||
|
||
# check required executive properties | ||
if exec_properties["bq_table_name"] is None: | ||
raise ValueError("bq_table_name must be set in exec_properties") | ||
if exec_properties["filter_threshold"] is None: | ||
raise ValueError("filter_threshold must be set in exec_properties") | ||
if exec_properties["vocab_label_file"] is None: | ||
raise ValueError("vocab_label_file must be set in exec_properties") | ||
|
||
# get labels from tf transform generated vocab file | ||
transform_output = artifact_utils.get_single_uri(input_dict["transform_graph"]) | ||
tf_transform_output = tft.TFTransformOutput(transform_output) | ||
tft_vocab = tf_transform_output.vocabulary_by_name(vocab_filename=exec_properties["vocab_label_file"]) | ||
labels = [label.decode() for label in tft_vocab] | ||
logging.info(f"found the following labels from TFT vocab: {labels}") | ||
|
||
# get predictions from predict log | ||
inference_results_uri = artifact_utils.get_single_uri(input_dict["inference_results"]) | ||
|
||
# set table prefix and partitioning parameters | ||
bq_table_name = exec_properties["bq_table_name"] | ||
if exec_properties["table_suffix"]: | ||
bq_table_name += "_" + ts.strftime(exec_properties["table_suffix"]) | ||
|
||
if exec_properties["expiration_time_delta"]: | ||
expiration_time = int(ts.timestamp()) + exec_properties["expiration_time_delta"] | ||
_ADDITIONAL_BQ_PARAMETERS.update({"expirationTime": str(expiration_time)}) | ||
logging.info(f"expiration time on {bq_table_name} set to {expiration_time}") | ||
|
||
if exec_properties["table_partitioning"]: | ||
_ADDITIONAL_BQ_PARAMETERS.update({"timePartitioning": {"type": "DAY"}}) | ||
logging.info(f"time partitioning on {bq_table_name} set to DAY") | ||
|
||
# set prediction result file path and decoder | ||
prediction_log_path = f"{inference_results_uri}/*.gz" | ||
prediction_log_decoder = beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog) | ||
|
||
# get features from tfx schema if present | ||
if input_dict["schema"]: | ||
schema_uri = os.path.join(artifact_utils.get_single_uri(input_dict["schema"]), _SCHEMA_FILE) | ||
features = load_schema(schema_uri) | ||
|
||
# generate features from predictions | ||
else: | ||
features = parse_schema(prediction_log_path) | ||
|
||
# generate bigquery schema from tfx schema (features) | ||
bq_schema_fields = feature_to_bq_schema(features, required=True) | ||
bq_schema_fields.extend( | ||
create_annotation_fields( | ||
label_field_name="category_label", score_field_name="score", required=True, add_datetime_field=True | ||
) | ||
) | ||
bq_schema = {"fields": bq_schema_fields} | ||
logging.info(f"generated bq_schema: {bq_schema}") | ||
|
||
with self._make_beam_pipeline() as pipeline: | ||
_ = ( | ||
pipeline | ||
| "Read Prediction Log" >> beam.io.ReadFromTFRecord(prediction_log_path, coder=prediction_log_decoder) | ||
| "Filter and Convert to Dict" | ||
>> beam.ParDo( | ||
FilterPredictionToDictFn( | ||
labels=labels, | ||
features=features, | ||
ts=ts, | ||
filter_threshold=exec_properties["filter_threshold"], | ||
) | ||
) | ||
| "Write Dict to BQ" | ||
>> beam.io.gcp.bigquery.WriteToBigQuery( | ||
table=bq_table_name, | ||
schema=bq_schema, | ||
additional_bq_parameters=_ADDITIONAL_BQ_PARAMETERS, | ||
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, | ||
write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE, | ||
) | ||
) | ||
|
||
bigquery_export = artifact_utils.get_single_instance(output_dict["bigquery_export"]) | ||
|
||
bigquery_export.set_string_custom_property("generated_bq_table_name", bq_table_name) | ||
|
||
logging.info(f"Annotated data exported to {bq_table_name}") |
@@ -0,0 +1,34 @@
""" | ||
Tests around Digits Prediction-to-BigQuery component. | ||
""" | ||
|
||
import tensorflow as tf | ||
from tfx.types import channel_utils, standard_artifacts | ||
|
||
from . import component | ||
|
||
|
||
class ComponentTest(tf.test.TestCase): | ||
def setUp(self): | ||
super(ComponentTest, self).setUp() | ||
self._transform_graph = channel_utils.as_channel([standard_artifacts.TransformGraph()]) | ||
self._inference_results = channel_utils.as_channel([standard_artifacts.InferenceResult()]) | ||
self._schema = channel_utils.as_channel([standard_artifacts.Schema()]) | ||
|
||
def testConstruct(self): | ||
# not a real test, just checking if if the component can be | ||
# instantiated | ||
_ = component.AnnotateUnlabeledCategoryDataComponent( | ||
transform_graph=self._transform_graph, | ||
inference_results=self._inference_results, | ||
schema=self._schema, | ||
bq_table_name="gcp_project:bq_database.table", | ||
vocab_label_file="vocab_txt", | ||
filter_threshold=0.1, | ||
table_suffix="%Y", | ||
table_partitioning=False, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
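
A note on running this test: because of the relative import (from . import component), it needs to be executed as a module from the package root rather than as a plain script, e.g. (the module path here is an assumption):

python -m predictions_to_bigquery.component_test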
nit: let's add the Copyright headers here please.
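
For reference, a header along these lines would address the nit, assuming the project uses the Apache 2.0 license (year and holder are placeholders):

# Copyright <year> <copyright holder>. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.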