Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --params option to %%bigquery magic #6277

Merged
merged 8 commits into from
Oct 30, 2018
19 changes: 19 additions & 0 deletions bigquery/google/cloud/bigquery/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
amount of time for the query to complete will not be cleared after the
query is finished. By default, this information will be displayed but
will be cleared after the query is finished.
* ``--params <params dictionary>`` (optional, line argument):
If present, the argument must be a parsable JSON string. This dictionary
alixhami marked this conversation as resolved.
Show resolved Hide resolved
will be used to format values enclosed within {} in the query.
tseaver marked this conversation as resolved.
Show resolved Hide resolved
* ``<query>`` (required, cell argument):
SQL query to run.

Expand Down Expand Up @@ -100,6 +103,7 @@

alixhami marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import print_function

import ast
import time
from concurrent import futures

Expand Down Expand Up @@ -249,6 +253,13 @@ def _run_query(client, query, job_config=None):
'amount of time for the query to finish. By default, this '
'information will be displayed as the query runs, but will be '
'cleared after the query is finished.'))
@magic_arguments.argument(
'--params',
nargs='+',
default=None,
help=('Parameters to format the query string. If present, it should be a '
'parsable JSON string. The parsed dictionary will be used for string'
alixhami marked this conversation as resolved.
Show resolved Hide resolved
'replacement in the query'))
def _cell_magic(line, query):
"""Underlying function for bigquery cell magic

Expand All @@ -265,6 +276,14 @@ def _cell_magic(line, query):
"""
args = magic_arguments.parse_argstring(_cell_magic, line)

if args.params is not None:
try:
params = ast.literal_eval(''.join(args.params))
except Exception:
raise SyntaxError('--params is not a correctly formatted JSON string')
alixhami marked this conversation as resolved.
Show resolved Hide resolved

query = query.format(**params)
tseaver marked this conversation as resolved.
Show resolved Hide resolved

project = args.project or context.project
client = bigquery.Client(project=project, credentials=context.credentials)
job_config = bigquery.job.QueryJobConfig()
Expand Down
30 changes: 30 additions & 0 deletions bigquery/tests/unit/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,33 @@ def test_bigquery_magic_with_project():
assert client_used.project == 'specific-project'
# context project should not change
assert magics.context.project == 'general-project'


@pytest.mark.usefixtures('ipython_interactive')
@pytest.mark.skipif(pandas is None, reason='Requires `pandas`')
def test_bigquery_magic_with_formatting_params():
ip = IPython.get_ipython()
ip.extension_manager.load_extension('google.cloud.bigquery')
magics.context.credentials = mock.create_autospec(
google.auth.credentials.Credentials, instance=True)

sql = 'SELECT {num} AS num'
result = pandas.DataFrame([17], columns=['num'])
assert 'myvariable' not in ip.user_ns
alixhami marked this conversation as resolved.
Show resolved Hide resolved

run_query_patch = mock.patch(
'google.cloud.bigquery.magics._run_query', autospec=True)
query_job_mock = mock.create_autospec(
google.cloud.bigquery.job.QueryJob, instance=True)
query_job_mock.to_dataframe.return_value = result
with run_query_patch as run_query_mock:
run_query_mock.return_value = query_job_mock

ip.run_cell_magic('bigquery', 'df --params {"num":17}', sql)
alixhami marked this conversation as resolved.
Show resolved Hide resolved
run_query_mock.assert_called_once_with(mock.ANY, sql.format(num=17), mock.ANY)

assert 'df' in ip.user_ns # verify that variable exists
df = ip.user_ns['df']
assert len(df) == len(result) # verify row count
assert list(df) == list(result) # verify column names
tseaver marked this conversation as resolved.
Show resolved Hide resolved