fix(bigquery): add close() method to client for releasing open sockets #9894

Merged · 5 commits · Nov 27, 2019
Changes from all commits
12 changes: 12 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
@@ -194,6 +194,18 @@ def location(self):
         """Default location for jobs / datasets / tables."""
         return self._location
 
+    def close(self):
+        """Close the underlying transport objects, releasing system resources.
+
+        .. note::
+
+            The client instance can be used for making additional requests even
+            after closing, in which case the underlying connections are
+            automatically re-created.
+        """
+        self._http._auth_request.session.close()
+        self._http.close()
+
     def get_service_account_email(self, project=None):
         """Get the email address of the project's BigQuery service account
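Note: `close()` shuts down both transports the client holds, the authorized session used for API requests (`self._http`) and the session its credentials use for token refresh (`self._http._auth_request.session`). A minimal usage sketch of the new method (the project id is a placeholder; credentials are assumed to come from the environment):

```python
from google.cloud import bigquery

client = bigquery.Client(project="my-project")  # placeholder project id

client.query("SELECT 1").result()  # opens sockets on demand

client.close()  # releases the open sockets

# Per the docstring above, the client remains usable after close();
# the underlying connections are re-created on the next request.
client.query("SELECT 2").result()
```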
160 changes: 92 additions & 68 deletions bigquery/google/cloud/bigquery/magics.py
@@ -137,6 +137,7 @@

import re
import ast
import functools
import sys
import time
from concurrent import futures
@@ -494,86 +495,91 @@ def _cell_magic(line, query):
         args.use_bqstorage_api or context.use_bqstorage_api, context.credentials
     )
 
-    if args.max_results:
-        max_results = int(args.max_results)
-    else:
-        max_results = None
-
-    query = query.strip()
-
-    # Any query that does not contain whitespace (aside from leading and trailing whitespace)
-    # is assumed to be a table id
-    if not re.search(r"\s", query):
-        try:
-            rows = client.list_rows(query, max_results=max_results)
-        except Exception as ex:
-            _handle_error(ex, args.destination_var)
-            return
-
-        result = rows.to_dataframe(bqstorage_client=bqstorage_client)
-        if args.destination_var:
-            IPython.get_ipython().push({args.destination_var: result})
-            return
-        else:
-            return result
-
-    job_config = bigquery.job.QueryJobConfig()
-    job_config.query_parameters = params
-    job_config.use_legacy_sql = args.use_legacy_sql
-    job_config.dry_run = args.dry_run
-
-    if args.destination_table:
-        split = args.destination_table.split(".")
-        if len(split) != 2:
-            raise ValueError(
-                "--destination_table should be in a <dataset_id>.<table_id> format."
-            )
-        dataset_id, table_id = split
-        job_config.allow_large_results = True
-        dataset_ref = client.dataset(dataset_id)
-        destination_table_ref = dataset_ref.table(table_id)
-        job_config.destination = destination_table_ref
-        job_config.create_disposition = "CREATE_IF_NEEDED"
-        job_config.write_disposition = "WRITE_TRUNCATE"
-        _create_dataset_if_necessary(client, dataset_id)
-
-    if args.maximum_bytes_billed == "None":
-        job_config.maximum_bytes_billed = 0
-    elif args.maximum_bytes_billed is not None:
-        value = int(args.maximum_bytes_billed)
-        job_config.maximum_bytes_billed = value
-
-    try:
-        query_job = _run_query(client, query, job_config=job_config)
-    except Exception as ex:
-        _handle_error(ex, args.destination_var)
-        return
-
-    if not args.verbose:
-        display.clear_output()
-
-    if args.dry_run and args.destination_var:
-        IPython.get_ipython().push({args.destination_var: query_job})
-        return
-    elif args.dry_run:
-        print(
-            "Query validated. This query will process {} bytes.".format(
-                query_job.total_bytes_processed
-            )
-        )
-        return query_job
-
-    if max_results:
-        result = query_job.result(max_results=max_results).to_dataframe(
-            bqstorage_client=bqstorage_client
-        )
-    else:
-        result = query_job.to_dataframe(bqstorage_client=bqstorage_client)
-
-    if args.destination_var:
-        IPython.get_ipython().push({args.destination_var: result})
-    else:
-        return result
+    close_transports = functools.partial(_close_transports, client, bqstorage_client)
+
+    try:
+        if args.max_results:
+            max_results = int(args.max_results)
+        else:
+            max_results = None
+
+        query = query.strip()
+
+        # Any query that does not contain whitespace (aside from leading and trailing whitespace)
+        # is assumed to be a table id
+        if not re.search(r"\s", query):
+            try:
+                rows = client.list_rows(query, max_results=max_results)
+            except Exception as ex:
+                _handle_error(ex, args.destination_var)
+                return
+
+            result = rows.to_dataframe(bqstorage_client=bqstorage_client)
+            if args.destination_var:
+                IPython.get_ipython().push({args.destination_var: result})
+                return
+            else:
+                return result
+
+        job_config = bigquery.job.QueryJobConfig()
+        job_config.query_parameters = params
+        job_config.use_legacy_sql = args.use_legacy_sql
+        job_config.dry_run = args.dry_run
+
+        if args.destination_table:
+            split = args.destination_table.split(".")
+            if len(split) != 2:
+                raise ValueError(
+                    "--destination_table should be in a <dataset_id>.<table_id> format."
+                )
+            dataset_id, table_id = split
+            job_config.allow_large_results = True
+            dataset_ref = client.dataset(dataset_id)
+            destination_table_ref = dataset_ref.table(table_id)
+            job_config.destination = destination_table_ref
+            job_config.create_disposition = "CREATE_IF_NEEDED"
+            job_config.write_disposition = "WRITE_TRUNCATE"
+            _create_dataset_if_necessary(client, dataset_id)
+
+        if args.maximum_bytes_billed == "None":
+            job_config.maximum_bytes_billed = 0
+        elif args.maximum_bytes_billed is not None:
+            value = int(args.maximum_bytes_billed)
+            job_config.maximum_bytes_billed = value
+
+        try:
+            query_job = _run_query(client, query, job_config=job_config)
+        except Exception as ex:
+            _handle_error(ex, args.destination_var)
+            return
+
+        if not args.verbose:
+            display.clear_output()
+
+        if args.dry_run and args.destination_var:
+            IPython.get_ipython().push({args.destination_var: query_job})
+            return
+        elif args.dry_run:
+            print(
+                "Query validated. This query will process {} bytes.".format(
+                    query_job.total_bytes_processed
+                )
+            )
+            return query_job
+
+        if max_results:
+            result = query_job.result(max_results=max_results).to_dataframe(
+                bqstorage_client=bqstorage_client
+            )
+        else:
+            result = query_job.to_dataframe(bqstorage_client=bqstorage_client)
+
+        if args.destination_var:
+            IPython.get_ipython().push({args.destination_var: result})
+        else:
+            return result
+    finally:
+        close_transports()
 
 
@@ -601,3 +607,21 @@ def _make_bqstorage_client(use_bqstorage_api, credentials):
         credentials=credentials,
         client_info=gapic_client_info.ClientInfo(user_agent=IPYTHON_USER_AGENT),
     )
+
+
+def _close_transports(client, bqstorage_client):
+    """Close the given clients' underlying transport channels.
+
+    Closing the transport is needed to release system resources, namely open
+    sockets.
+
+    Args:
+        client (:class:`~google.cloud.bigquery.client.Client`):
+        bqstorage_client
+            (Optional[:class:`~google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient`]):
+            A client for the BigQuery Storage API.
+
+    """
+    client.close()
+    if bqstorage_client is not None:
+        bqstorage_client.transport.channel.close()
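The magic now funnels every exit path through one cleanup call: `close_transports` is bound up front with `functools.partial` and invoked in a `finally` block, so sockets are released on normal completion, early `return`, and exception alike. A standalone sketch of that pattern (the class and method names here are illustrative, not the library's):

```python
import functools


class FakeClient:
    """Stand-in for a client that owns open sockets (illustrative only)."""

    def do_work(self):
        return "result"

    def close(self):
        print("transport closed")


def run(client):
    # Bind the cleanup callable up front, mirroring
    # functools.partial(_close_transports, client, bqstorage_client) above.
    close_transports = functools.partial(client.close)
    try:
        return client.do_work()  # may raise or return early
    finally:
        close_transports()  # runs on every exit path


print(run(FakeClient()))  # prints "transport closed", then "result"
```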
2 changes: 1 addition & 1 deletion bigquery/noxfile.py
@@ -81,7 +81,7 @@ def system(session):
     session.install("--pre", "grpcio")
 
     # Install all test dependencies, then install local packages in place.
-    session.install("mock", "pytest")
+    session.install("mock", "pytest", "psutil")
     for local_dep in LOCAL_DEPS:
         session.install("-e", local_dep)
     session.install("-e", os.path.join("..", "storage"))
28 changes: 28 additions & 0 deletions bigquery/tests/system.py
@@ -27,6 +27,7 @@
 import re
 
 import six
+import psutil
 import pytest
 import pytz

@@ -203,6 +204,27 @@ def _create_bucket(self, bucket_name, location=None):
 
         return bucket
 
+    def test_close_releases_open_sockets(self):
+        current_process = psutil.Process()
+        conn_count_start = len(current_process.connections())
+
+        client = Config.CLIENT
+        client.query(
+            """
+            SELECT
+                source_year AS year, COUNT(is_male) AS birth_count
+            FROM `bigquery-public-data.samples.natality`
+            GROUP BY year
+            ORDER BY year DESC
+            LIMIT 15
+            """
+        )
+
+        client.close()
+
+        conn_count_end = len(current_process.connections())
+        self.assertEqual(conn_count_end, conn_count_start)
+
     def test_create_dataset(self):
         DATASET_ID = _make_dataset_id("create_dataset")
         dataset = self.temp_dataset(DATASET_ID)
@@ -2417,6 +2439,9 @@ def temp_dataset(self, dataset_id, location=None):
 @pytest.mark.usefixtures("ipython_interactive")
 def test_bigquery_magic():
     ip = IPython.get_ipython()
+    current_process = psutil.Process()
+    conn_count_start = len(current_process.connections())
+
     ip.extension_manager.load_extension("google.cloud.bigquery")
     sql = """
         SELECT
@@ -2432,6 +2457,8 @@ def test_bigquery_magic():
     with io.capture_output() as captured:
         result = ip.run_cell_magic("bigquery", "", sql)
 
+    conn_count_end = len(current_process.connections())
+
     lines = re.split("\n|\r", captured.stdout)
     # Removes blanks & terminal code (result of display clearing)
     updates = list(filter(lambda x: bool(x) and x != "\x1b[2K", lines))
@@ -2441,6 +2468,7 @@
     assert isinstance(result, pandas.DataFrame)
     assert len(result) == 10  # verify row count
     assert list(result) == ["url", "view_count"]  # verify column names
+    assert conn_count_end == conn_count_start  # system resources are released
 
 
 def _job_done(instance):
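The system tests above measure leaked sockets directly with `psutil`. A condensed sketch of the same check outside the test harness (assumes default credentials and network access; the equality assertion mirrors the one in `test_close_releases_open_sockets`):

```python
import psutil

from google.cloud import bigquery


def open_socket_count():
    # Connections currently held by this process, counted the same
    # way the system tests count them.
    return len(psutil.Process().connections())


before = open_socket_count()

client = bigquery.Client()
client.query("SELECT 1")  # opens HTTP sockets on demand

client.close()

assert open_socket_count() == before, "close() should release all sockets"
```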
11 changes: 11 additions & 0 deletions bigquery/tests/unit/test_client.py
@@ -1398,6 +1398,17 @@ def test_create_table_alreadyexists_w_exists_ok_true(self):
             ]
         )
 
+    def test_close(self):
+        creds = _make_credentials()
+        http = mock.Mock()
+        http._auth_request.session = mock.Mock()
+        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
+
+        client.close()
+
+        http.close.assert_called_once()
+        http._auth_request.session.close.assert_called_once()
+
     def test_get_model(self):
         path = "projects/%s/datasets/%s/models/%s" % (
             self.PROJECT,
67 changes: 67 additions & 0 deletions bigquery/tests/unit/test_magics.py
@@ -545,6 +545,7 @@ def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch):
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -601,6 +602,7 @@ def test_bigquery_magic_with_bqstorage_from_context(monkeypatch):
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -728,6 +730,41 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result():
     query_job_mock.result.assert_called_with(max_results=5)
 
 
+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_w_max_results_query_job_results_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context._project = None
+
+    credentials_mock = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+    default_patch = mock.patch(
+        "google.auth.default", return_value=(credentials_mock, "general-project")
+    )
+    client_query_patch = mock.patch(
+        "google.cloud.bigquery.client.Client.query", autospec=True
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    sql = "SELECT 17 AS num"
+
+    query_job_mock = mock.create_autospec(
+        google.cloud.bigquery.job.QueryJob, instance=True
+    )
+    query_job_mock.result.side_effect = [[], OSError]
+
+    with pytest.raises(
+        OSError
+    ), client_query_patch as client_query_mock, default_patch, close_transports_patch as close_transports:
+        client_query_mock.return_value = query_job_mock
+        ip.run_cell_magic("bigquery", "--max_results=5", sql)
+
+    assert close_transports.called
+
+
 def test_bigquery_magic_w_table_id_invalid():
     ip = IPython.get_ipython()
     ip.extension_manager.load_extension("google.cloud.bigquery")
@@ -820,6 +857,7 @@ def test_bigquery_magic_w_table_id_and_bqstorage_client():
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -1290,3 +1328,32 @@ def test_bigquery_magic_w_destination_table():
     assert job_config_used.write_disposition == "WRITE_TRUNCATE"
     assert job_config_used.destination.dataset_id == "dataset_id"
     assert job_config_used.destination.table_id == "table_id"
+
+
+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_create_dataset_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context.credentials = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+
+    create_dataset_if_necessary_patch = mock.patch(
+        "google.cloud.bigquery.magics._create_dataset_if_necessary",
+        autospec=True,
+        side_effect=OSError,
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    with pytest.raises(
+        OSError
+    ), create_dataset_if_necessary_patch, close_transports_patch as close_transports:
+        ip.run_cell_magic(
+            "bigquery",
+            "--destination_table dataset_id.table_id",
+            "SELECT foo FROM WHERE LIMIT bar",
+        )
+
+    assert close_transports.called