Skip to content

Commit

Permalink
chore: use a runtime parameter for pusher destination in the TFX exam…
Browse files Browse the repository at this point in the history
…ple. (#6373)

{} Placeholder doesn't work well in component parameters and
it is better to have it as a runtime parameter for flexibility.

#6311 (comment)
  • Loading branch information
Jiyong Jung authored Aug 18, 2021
1 parent f3f383c commit 053edb5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
27 changes: 17 additions & 10 deletions samples/core/parameterized_tfx_oss/parameterized_tfx_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import kfp
Expand All @@ -32,14 +33,24 @@
_data_root = '/opt/conda/lib/python3.7/site-packages/tfx/examples/chicago_taxi_pipeline/data/simple'

# Path of pipeline root, should be a GCS path.
pipeline_root = os.path.join(
_pipeline_root = os.path.join(
'gs://{{kfp-default-bucket}}', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER
)

# Path that ML models are pushed, should be a GCS path.
_serving_model_dir = os.path.join('gs://your-bucket', 'serving_model', 'tfx_taxi_simple')
_push_destination = tfx.dsl.experimental.RuntimeParameter(
name='push_destination',
default=json.dumps({'filesystem': {'base_directory': _serving_model_dir}}),
ptype=str,
)

def _create_pipeline(
pipeline_root: str, csv_input_location: str,
taxi_module_file: tfx.dsl.experimental.RuntimeParameter, enable_cache: bool
pipeline_root: str,
csv_input_location: str,
taxi_module_file: tfx.dsl.experimental.RuntimeParameter,
push_destination: tfx.dsl.experimental.RuntimeParameter,
enable_cache: bool
):
"""Creates a simple Kubeflow-based Chicago Taxi TFX pipeline.
Expand Down Expand Up @@ -125,12 +136,7 @@ def _create_pipeline(
pusher = tfx.components.Pusher(
model=trainer.outputs['model'],
model_blessing=evaluator.outputs['blessing'],
push_destination=tfx.proto.PushDestination(
filesystem=tfx.proto.PushDestination.Filesystem(
base_directory=os.path.
join(pipeline_root, 'model_serving')
)
),
push_destination=push_destination,
)

return tfx.dsl.Pipeline(
Expand All @@ -147,9 +153,10 @@ def _create_pipeline(
if __name__ == '__main__':
enable_cache = True
pipeline = _create_pipeline(
pipeline_root,
_pipeline_root,
_data_root,
_taxi_module_file_param,
_push_destination,
enable_cache=enable_cache,
)
# Make sure the version of TFX image used is consistent with the version of
Expand Down
16 changes: 10 additions & 6 deletions samples/core/parameterized_tfx_oss/taxi_pipeline_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import kfp\n",
Expand Down Expand Up @@ -113,6 +114,14 @@
" name='module-file',\n",
" default='/opt/conda/lib/python3.7/site-packages/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py',\n",
" ptype=str,\n",
")\n",
"# Path that ML models are pushed, should be a GCS path.\n",
"# TODO: CHANGE the GCS bucket name to yours.\n",
"serving_model_dir = os.path.join('gs://your-bucket', 'serving_model', 'tfx_taxi_simple')\n",
"push_destination = tfx.dsl.experimental.RuntimeParameter(\n",
" name='push_destination',\n",
" default=json.dumps({'filesystem': {'base_directory': serving_model_dir}}),\n",
" ptype=str,\n",
")"
]
},
Expand All @@ -131,8 +140,6 @@
"metadata": {},
"outputs": [],
"source": [
"# The input data location is parameterized by _data_root_param\n",
"\n",
"example_gen = tfx.components.CsvExampleGen(input_base=data_root)\n",
"\n",
"statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])\n",
Expand Down Expand Up @@ -207,10 +214,7 @@
"pusher = tfx.components.Pusher(\n",
" model=trainer.outputs['model'],\n",
" model_blessing=evaluator.outputs['blessing'],\n",
" push_destination=tfx.proto.PushDestination(\n",
" filesystem=tfx.proto.PushDestination.Filesystem(\n",
" base_directory=os.path.join(\n",
" pipeline_root, 'model_serving'))))"
" push_destination=push_destination)"
]
},
{
Expand Down

0 comments on commit 053edb5

Please sign in to comment.