Skip to content

Commit

Permalink
ddel: validate provider-specific arguments
Browse files Browse the repository at this point in the history
Problem was reported in #87.

PiperOrigin-RevId: 173580233
  • Loading branch information
mbookman authored and eap committed Oct 27, 2017
1 parent deae452 commit 9cce71a
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 71 deletions.
45 changes: 22 additions & 23 deletions dsub/commands/ddel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
Follows the model of qdel.
"""
import argparse
import sys

from ..lib import dsub_util
from ..lib import param_util
from ..lib import resources
from ..providers import provider_base


def parse_arguments():
def _parse_arguments():
"""Parses command line arguments.
Returns:
Expand All @@ -33,25 +33,11 @@ def parse_arguments():
# Handle version flag and exit if it was passed.
param_util.handle_version_flag()

provider_required_args = {
'google': ['project'],
'test-fails': [],
'local': [],
}
epilog = 'Provider-required arguments:\n'
for provider in provider_required_args:
epilog += ' %s: %s\n' % (provider, provider_required_args[provider])
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, epilog=epilog)
provider_base.add_provider_argument(parser)
parser = provider_base.create_parser(sys.argv[0])

parser.add_argument(
'--version', '-v', default=False, help='Print the dsub version and exit.')
google = parser.add_argument_group(
title='google',
description='Options for the Google provider (Pipelines API)')
google.add_argument(
'--project',
help='Cloud project ID in which to find and delete the job(s)')

parser.add_argument(
'--jobs',
'-j',
Expand Down Expand Up @@ -83,10 +69,23 @@ def parse_arguments():
default=[],
help='User labels to match. Tasks returned must match all labels.',
metavar='KEY=VALUE')
return parser.parse_args()

# Add provider-specific arguments
google = parser.add_argument_group(
title='google',
description='Options for the Google provider (Pipelines API)')
google.add_argument(
'--project',
help='Cloud project ID in which to find and delete the job(s)')

return provider_base.parse_args(parser, {
'google': ['project'],
'test-fails': [],
'local': [],
}, sys.argv[1:])


def emit_search_criteria(users, jobs, tasks, labels):
def _emit_search_criteria(users, jobs, tasks, labels):
"""Print the filters used to delete tasks. Use raw flags as arguments."""
print 'Delete running jobs:'
print ' user:'
Expand All @@ -104,7 +103,7 @@ def emit_search_criteria(users, jobs, tasks, labels):

def main():
# Parse args and validate
args = parse_arguments()
args = _parse_arguments()

# Compute the age filter (if any)
create_time = param_util.age_to_create_time(args.age)
Expand All @@ -122,7 +121,7 @@ def main():

# Let the user know which jobs we are going to look up
with dsub_util.replace_print():
emit_search_criteria(user_list, args.jobs, args.tasks, args.label)
_emit_search_criteria(user_list, args.jobs, args.tasks, args.label)
# Delete the requested jobs
deleted_tasks = ddel_tasks(provider, user_list, args.jobs, args.tasks,
labels, create_time)
Expand Down
50 changes: 23 additions & 27 deletions dsub/commands/dstat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

from __future__ import print_function

import argparse
import collections
from datetime import datetime
import json
import sys
import time
from dateutil.tz import tzlocal

Expand Down Expand Up @@ -192,7 +192,7 @@ def print_table(self, table):
print(json.dumps(table, indent=2, default=self.serialize))


def prepare_row(task, full):
def _prepare_row(task, full):
"""return a dict with the task's info (more if "full" is set)."""

# Would like to include the Job ID in the default set of columns, but
Expand Down Expand Up @@ -237,7 +237,7 @@ def prepare_row(task, full):
return row


def parse_arguments():
def _parse_arguments():
"""Parses command line arguments.
Returns:
Expand All @@ -246,21 +246,11 @@ def parse_arguments():
# Handle version flag and exit if it was passed.
param_util.handle_version_flag()

provider_required_args = {
'google': ['project'],
'test-fails': [],
'local': [],
}
epilog = 'Provider-required arguments:\n'
for provider in provider_required_args:
epilog += ' %s: %s\n' % (provider, provider_required_args[provider])
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, epilog=epilog)
parser = provider_base.create_parser(sys.argv[0])

parser.add_argument(
'--version', '-v', default=False, help='Print the dsub version and exit.')
parser.add_argument(
'--project',
help='Cloud project ID in which to query pipeline operations')

parser.add_argument(
'--jobs',
'-j',
Expand Down Expand Up @@ -290,7 +280,9 @@ def parse_arguments():
default=['RUNNING'],
choices=['RUNNING', 'SUCCESS', 'FAILURE', 'CANCELED', '*'],
help="""Lists only those jobs which match the specified status(es).
Use "*" to list jobs of any status.""")
Choose from {'RUNNING', 'SUCCESS', 'FAILURE', 'CANCELED'}.
Use "*" to list jobs of any status.""",
metavar='STATUS')
parser.add_argument(
'--age',
help="""List only those jobs newer than the specified age. Ages can be
Expand Down Expand Up @@ -326,21 +318,25 @@ def parse_arguments():
'--format',
choices=['text', 'json', 'yaml', 'provider-json'],
help='Set the output format.')
# Add provider-specific arguments
provider_base.add_provider_argument(parser)

args = parser.parse_args()
# Add provider-specific arguments
google = parser.add_argument_group(
title='google',
description='Options for the Google provider (Pipelines API)')
google.add_argument(
'--project',
help='Cloud project ID in which to find and delete the job(s)')

# check special flag rules
for arg in provider_required_args[args.provider]:
if not args.__getattribute__(arg):
parser.error('argument --%s is required' % arg)
return args
return provider_base.parse_args(parser, {
'google': ['project'],
'test-fails': [],
'local': [],
}, sys.argv[1:])


def main():
# Parse args and validate
args = parse_arguments()
args = _parse_arguments()

# Compute the age filter (if any)
create_time = param_util.age_to_create_time(args.age)
Expand Down Expand Up @@ -459,7 +455,7 @@ def dstat_job_producer(provider,
if raw_format:
formatted_tasks.append(task.raw_task_data())
else:
formatted_tasks.append(prepare_row(task, full_output))
formatted_tasks.append(_prepare_row(task, full_output))

# Determine if any of the jobs are running.
if task.get_field('task-status') == 'RUNNING':
Expand Down
28 changes: 8 additions & 20 deletions dsub/commands/dsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,12 @@ def _parse_arguments(prog, argv):
# Handle version flag and exit if it was passed.
param_util.handle_version_flag()

provider_required_args = {
'google': ['project', 'zones', 'logging'],
'test-fails': [],
'local': ['logging'],
}
epilog = 'Provider-required arguments:\n'
for provider in provider_required_args:
epilog += ' %s: %s\n' % (provider, provider_required_args[provider])
parser = argparse.ArgumentParser(
prog=prog,
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=epilog)
parser = provider_base.create_parser(prog)

# Add dsub core job submission arguments
parser.add_argument(
'--version', '-v', default=False, help='Print the dsub version and exit.')

parser.add_argument(
'--name',
help="""Name for pipeline. Defaults to the script name or
Expand Down Expand Up @@ -332,7 +322,6 @@ def _parse_arguments(prog, argv):
' (either a folder, or file ending in ".log")')

# Add provider-specific arguments
provider_base.add_provider_argument(parser)
google = parser.add_argument_group(
title='google',
description='Options for the Google provider (Pipelines API)')
Expand Down Expand Up @@ -363,13 +352,12 @@ def _parse_arguments(prog, argv):
Allows for connecting to the VM for debugging.
Default is 0; maximum allowed value is 86400 (1 day).""")

args = parser.parse_args(argv)

# check special flag rules
for arg in provider_required_args[args.provider]:
if not args.__getattribute__(arg):
parser.error('argument --%s is required' % arg)
return args
return provider_base.parse_args(
parser, {
'google': ['project', 'zones', 'logging'],
'test-fails': [],
'local': ['logging'],
}, argv)


def _get_job_resources(args):
Expand Down
40 changes: 39 additions & 1 deletion dsub/providers/provider_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Interface for job providers."""
import argparse

from . import google
from . import local
Expand Down Expand Up @@ -48,7 +49,22 @@ def get_provider_name(provider):
return PROVIDER_NAME_MAP[provider.__class__]


def add_provider_argument(parser):
class DsubHelpFormatter(argparse.ArgumentDefaultsHelpFormatter,
argparse.RawDescriptionHelpFormatter):
"""Display defaults in help and display the epilog in its raw format.
As described in https://bugs.python.org/issue13023, there is not a built-in
class to provide both display of defaults as well as displaying the epilog
just as you want it to. The recommended approach is to create a simple
subclass of both Formatters.
"""
pass


def create_parser(prog):
"""Create an argument parser, adding in the list of providers."""
parser = argparse.ArgumentParser(prog=prog, formatter_class=DsubHelpFormatter)

parser.add_argument(
'--provider',
default='google',
Expand All @@ -58,6 +74,28 @@ def add_provider_argument(parser):
are for testing purposes only.""",
metavar='PROVIDER')

return parser


def parse_args(parser, provider_required_args, argv):
"""Add provider required arguments epilog message, parse, and validate."""

# Add the provider required arguments epilog message
epilog = 'Provider-required arguments:\n'
for provider in provider_required_args:
epilog += ' %s: %s\n' % (provider, provider_required_args[provider])
parser.epilog = epilog

# Parse arguments
args = parser.parse_args(argv)

# For the selected provider, check the required arguments
for arg in provider_required_args[args.provider]:
if not args.__getattribute__(arg):
parser.error('argument --%s is required' % arg)

return args


def get_dstat_provider_args(provider, project):
"""A string with the arguments to point dstat to the same provider+project."""
Expand Down

0 comments on commit 9cce71a

Please sign in to comment.