Skip to content

Commit

Permalink
Some improvements to utilities for testing notebooks
Browse files Browse the repository at this point in the history
* Changes pulled in from kuueflow/examples#764

* Notebook tests should print a link to the stackdriver logs for
  the actual notebook job.

* Related to kubeflow/testing#613
  • Loading branch information
gabrielwen authored and Jeremy Lewi committed Jun 10, 2020
1 parent b8b7179 commit e5d1341
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
14 changes: 14 additions & 0 deletions py/kubeflow/examples/notebook_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def pytest_addoption(parser):
"--notebook_path", help=("Path to the testing notebook file, starting from"
"the base directory of examples repository."),
type=str, default="")
parser.addoption(
"--test-target-name", help=("Test target name, used as junit class name."),
type=str, default="")
parser.addoption(
"--artifacts-gcs", help=("GCS to upload artifacts to."),
type=str, default="")

@pytest.fixture
def name(request):
Expand All @@ -40,3 +46,11 @@ def repos(request):
@pytest.fixture
def notebook_path(request):
return request.config.getoption("--notebook_path")

@pytest.fixture
def test_target_name(request):
return request.config.getoption("--test-target-name")

@pytest.fixture
def artifacts_gcs(request):
return request.config.getoption("--artifacts-gcs")
42 changes: 36 additions & 6 deletions py/kubeflow/examples/notebook_tests/nb_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import datetime
import logging
import os
from urllib.parse import urlencode
import uuid
import tempfile
import yaml

from google.cloud import storage
Expand All @@ -20,8 +20,25 @@
NB_BUCKET = "kubeflow-ci-deployment"
PROJECT = "kbueflow-ci-deployment"

def logs_for_job(project, job_name):
"""Get a stack driver link for the job with the specified name."""
logs_filter = f"""resource.type="k8s_container"
labels."k8s-pod/job-name" = "{job_name}"
"""

new_params = {"project": project,
# Logs for last 7 days
"interval": 'P7D',
"advancedFilter": logs_filter}

query = urlencode(new_params)

url = "https://console.cloud.google.com/logs/viewer?" + query

return url

def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many-branches,too-many-statements
repos, image):
repos, image, artifacts_gcs="", test_target_name=""):
"""Generate a K8s job to run a notebook using papermill
Args:
Expand All @@ -41,7 +58,7 @@ def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many

if notebook_path.startswith("/"):
raise ValueError("notebook_path={0} should not start with /".format(
notebook_path))
notebook_path))

# We need to checkout the correct version of the code
# in presubmits and postsubmits. We should check the environment variables
Expand All @@ -51,6 +68,7 @@ def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many
# https://github.com/kubernetes/test-infra/blob/45246b09ed105698aa8fb928b7736d14480def29/prow/jobs.md#job-environment-variables
if not repos:
repos = argo_build_util.get_repo_from_prow_env()
logging.info(f"Using repos {repos}")

if not repos:
raise ValueError("Could not get repos from prow environment variable "
Expand All @@ -75,12 +93,18 @@ def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many
"--notebook_path", full_notebook_path]

job["spec"]["template"]["spec"]["containers"][0][
"workingDir"] = os.path.dirname(full_notebook_path)
"workingDir"] = os.path.dirname(full_notebook_path)

# The prow bucket to use for results/artifacts
prow_bucket = prow_artifacts.PROW_RESULTS_BUCKET

if os.getenv("REPO_OWNER") and os.getenv("REPO_NAME"):
if artifacts_gcs:
prow_dir = os.path.join(artifacts_gcs, "artifacts")
if test_target_name:
prow_dir = os.path.join(prow_dir, test_target_name)
logging.info("Prow artifacts directory: %s", prow_dir)
prow_bucket, prow_path = util.split_gcs_uri(prow_dir)
elif os.getenv("REPO_OWNER") and os.getenv("REPO_NAME"):
# Running under prow
prow_dir = prow_artifacts.get_gcs_dir(prow_bucket)
logging.info("Prow artifacts dir: %s", prow_dir)
Expand Down Expand Up @@ -128,11 +152,18 @@ def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many
logging.info("Created job %s.%s:\n%s", namespace, name,
yaml.safe_dump(actual_job.to_dict()))

logging.info("*********************Job logs************************")
logging.info(logs_for_job(PROJECT, name))
logging.info("*****************************************************")
final_job = util.wait_for_job(api_client, namespace, name,
timeout=datetime.timedelta(minutes=30))

logging.info("Final job:\n%s", yaml.safe_dump(final_job.to_dict()))

logging.info("*********************Job logs************************")
logging.info(logs_for_job(PROJECT, name))
logging.info("*****************************************************")

# Download notebook html to artifacts
logging.info("Copying %s to bucket %s", output_gcs, prow_bucket)

Expand All @@ -151,4 +182,3 @@ def run_papermill_job(notebook_path, name, namespace, # pylint: disable=too-many
if last_condition.type not in ["Complete"]:
logging.error("Job didn't complete successfully")
raise RuntimeError("Job {0}.{1} failed".format(namespace, name))

9 changes: 5 additions & 4 deletions py/kubeflow/examples/notebook_tests/run_notebook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
from kubeflow.testing import util

def test_run_notebook(record_xml_attribute, namespace, # pylint: disable=too-many-branches,too-many-statements
repos, image, notebook_path):
repos, image, notebook_path, test_target_name,
artifacts_gcs):
notebook_name = os.path.basename(
notebook_path).replace(".ipynb", "").replace("_", "-")
junit_name = "_".join(["test", notebook_name])
util.set_pytest_junit(record_xml_attribute, junit_name)
util.set_pytest_junit(record_xml_attribute, junit_name, test_target_name)

name = "-".join([notebook_name,
datetime.datetime.now().strftime("%H%M%S"),
uuid.uuid4().hex[0:3]])

util.set_pytest_junit(record_xml_attribute, junit_name)
nb_test_util.run_papermill_job(notebook_path, name, namespace, repos, image)
nb_test_util.run_papermill_job(notebook_path, name, namespace, repos, image,
artifacts_gcs, test_target_name)

if __name__ == '__main__':
logging.basicConfig(level=logging.INFO,
Expand Down

0 comments on commit e5d1341

Please sign in to comment.