feat(bigquery): add support of model for extract job (#71)
* feat(bigquery): add support of model for extract job

* feat(bigquery): nit

* feat(bigquery): add source model for create job method

* feat(bigquery): nits

* feat(bigquery): nit
HemangChothani authored May 11, 2020
1 parent df29b7d commit 4a7a514
Showing 5 changed files with 252 additions and 20 deletions.
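
In short, this change lets ``Client.extract_table`` export a BigQuery ML model to Cloud Storage, not just a table, via a new ``source_type`` parameter. A minimal usage sketch, assuming authenticated default credentials; the project, dataset, model, and bucket names are illustrative:

    from google.cloud import bigquery

    client = bigquery.Client()

    # New in this commit: source_type="Model" extracts a BQML model instead
    # of a table. The source may be a string ID, a Model, or a ModelReference.
    extract_job = client.extract_table(
        "my_dataset.my_model",            # illustrative model ID
        "gs://my-bucket/model-export/*",  # illustrative destination URI
        source_type="Model",
    )
    extract_job.result()  # block until the extract job completes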
38 changes: 31 additions & 7 deletions google/cloud/bigquery/client.py
@@ -65,6 +65,7 @@
from google.cloud.bigquery import job
from google.cloud.bigquery.model import Model
from google.cloud.bigquery.model import ModelReference
from google.cloud.bigquery.model import _model_arg_to_model_ref
from google.cloud.bigquery.query import _QueryResults
from google.cloud.bigquery.retry import DEFAULT_RETRY
from google.cloud.bigquery.routine import Routine
@@ -1364,9 +1365,17 @@ def create_job(self, job_config, retry=DEFAULT_RETRY):
                job_config
            )
            source = _get_sub_prop(job_config, ["extract", "sourceTable"])
            source_type = "Table"
            if not source:
                source = _get_sub_prop(job_config, ["extract", "sourceModel"])
                source_type = "Model"
            destination_uris = _get_sub_prop(job_config, ["extract", "destinationUris"])
            return self.extract_table(
                source,
                destination_uris,
                job_config=extract_job_config,
                retry=retry,
                source_type=source_type,
            )
        elif "query" in job_config:
            copy_config = copy.deepcopy(job_config)
@@ -2282,6 +2291,7 @@ def extract_table(
        job_config=None,
        retry=DEFAULT_RETRY,
        timeout=None,
        source_type="Table",
    ):
        """Start a job to extract a table into Cloud Storage files.
@@ -2292,9 +2302,11 @@
            source (Union[ \
                google.cloud.bigquery.table.Table, \
                google.cloud.bigquery.table.TableReference, \
                google.cloud.bigquery.model.Model, \
                google.cloud.bigquery.model.ModelReference, \
                str, \
            ]):
                Table or Model to be extracted.
            destination_uris (Union[str, Sequence[str]]):
                URIs of Cloud Storage file(s) into which table data is to be
                extracted; in format
@@ -2319,17 +2331,19 @@ def extract_table(
            timeout (Optional[float]):
                The number of seconds to wait for the underlying HTTP transport
                before using ``retry``.
            source_type (str):
                (Optional) Type of source to be extracted. ``Table`` or ``Model``.
                Defaults to ``Table``.

        Returns:
            google.cloud.bigquery.job.ExtractJob: A new extract job instance.

        Raises:
            TypeError:
                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.ExtractJobConfig`
                class.
            ValueError:
                If ``source_type`` is not among ``Table``, ``Model``.
        """
        job_id = _make_job_id(job_id, job_id_prefix)

        if project is None:
@@ -2339,7 +2353,17 @@ def extract_table(
            location = self.location

        job_ref = job._JobReference(job_id, project=project, location=location)
        src = source_type.lower()
        if src == "table":
            source = _table_arg_to_table_ref(source, default_project=self.project)
        elif src == "model":
            source = _model_arg_to_model_ref(source, default_project=self.project)
        else:
            raise ValueError(
                "Cannot pass `{}` as a ``source_type``, pass Table or Model".format(
                    source_type
                )
            )

        if isinstance(destination_uris, six.string_types):
            destination_uris = [destination_uris]
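With the ``create_job`` change above, a raw extract configuration that names a ``sourceModel`` is now routed to ``extract_table`` with ``source_type="Model"``. A hedged sketch of such a configuration, reusing the client from the earlier sketch; the IDs are placeholders:

    job_config = {
        "extract": {
            "sourceModel": {
                "projectId": "my-project",
                "datasetId": "my_dataset",
                "modelId": "my_model",
            },
            "destinationUris": ["gs://my-bucket/model-export/*"],
        }
    }
    # create_job detects the sourceModel key and dispatches to
    # extract_table(..., source_type="Model").
    job = client.create_job(job_config)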
37 changes: 28 additions & 9 deletions google/cloud/bigquery/job.py
@@ -1990,8 +1990,11 @@ class ExtractJob(_AsyncJob):
    Args:
        job_id (str): the job's ID.

        source (Union[ \
            google.cloud.bigquery.table.TableReference, \
            google.cloud.bigquery.model.ModelReference \
        ]):
            Table or Model from which data is to be extracted.

        destination_uris (List[str]):
            URIs describing where the extracted data will be written in Cloud
@@ -2067,14 +2070,20 @@ def destination_uri_file_counts(self):
    def to_api_repr(self):
        """Generate a resource for :meth:`_begin`."""

        configuration = self._configuration.to_api_repr()
        source_ref = {
            "projectId": self.source.project,
            "datasetId": self.source.dataset_id,
        }

        source = "sourceTable"
        if isinstance(self.source, TableReference):
            source_ref["tableId"] = self.source.table_id
        else:
            source_ref["modelId"] = self.source.model_id
            source = "sourceModel"

        _helpers._set_sub_prop(configuration, ["extract", source], source_ref)
        _helpers._set_sub_prop(
            configuration, ["extract", "destinationUris"], self.destination_uris
        )
@@ -2112,10 +2121,20 @@ def from_api_repr(cls, resource, client):
        source_config = _helpers._get_sub_prop(
            config_resource, ["extract", "sourceTable"]
        )
        if source_config:
            dataset = DatasetReference(
                source_config["projectId"], source_config["datasetId"]
            )
            source = dataset.table(source_config["tableId"])
        else:
            source_config = _helpers._get_sub_prop(
                config_resource, ["extract", "sourceModel"]
            )
            dataset = DatasetReference(
                source_config["projectId"], source_config["datasetId"]
            )
            source = dataset.model(source_config["modelId"])

        destination_uris = _helpers._get_sub_prop(
            config_resource, ["extract", "destinationUris"]
        )
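For reference, the ``to_api_repr`` change above means an extract job resource now carries one of two source shapes on the wire; a sketch with placeholder field values:

    # Table source (existing shape):
    table_config = {
        "extract": {
            "sourceTable": {"projectId": "p", "datasetId": "d", "tableId": "t"},
            "destinationUris": ["gs://bucket/file*"],
        }
    }

    # Model source (new shape, emitted when job.source is a ModelReference):
    model_config = {
        "extract": {
            "sourceModel": {"projectId": "p", "datasetId": "d", "modelId": "m"},
            "destinationUris": ["gs://bucket/file*"],
        }
    }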
12 changes: 12 additions & 0 deletions google/cloud/bigquery/model.py
@@ -433,3 +433,15 @@ def __repr__(self):
return "ModelReference(project_id='{}', dataset_id='{}', model_id='{}')".format(
self.project, self.dataset_id, self.model_id
)


def _model_arg_to_model_ref(value, default_project=None):
"""Helper to convert a string or Model to ModelReference.
This function keeps ModelReference and other kinds of objects unchanged.
"""
if isinstance(value, six.string_types):
return ModelReference.from_string(value, default_project=default_project)
if isinstance(value, Model):
return value.reference
return value
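
A quick sketch of how the new helper normalizes its inputs, mirroring the existing ``_table_arg_to_table_ref``. It is a private helper, shown only for illustration; the IDs are placeholders:

    from google.cloud.bigquery.model import (
        Model,
        ModelReference,
        _model_arg_to_model_ref,
    )

    # A string model ID is parsed into a ModelReference.
    ref = _model_arg_to_model_ref("my_dataset.my_model", default_project="my-project")
    assert isinstance(ref, ModelReference)

    # A Model instance is reduced to its reference.
    model = Model("my-project.my_dataset.my_model")
    assert _model_arg_to_model_ref(model) == model.reference

    # A ModelReference (or any other object) passes through unchanged.
    assert _model_arg_to_model_ref(ref) is ref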
149 changes: 149 additions & 0 deletions tests/unit/test_client.py
@@ -2884,6 +2884,21 @@ def test_create_job_extract_config(self):
configuration, "google.cloud.bigquery.client.Client.extract_table",
)

    def test_create_job_extract_config_for_model(self):
        configuration = {
            "extract": {
                "sourceModel": {
                    "projectId": self.PROJECT,
                    "datasetId": self.DS_ID,
                    "modelId": "source_model",
                },
                "destinationUris": ["gs://test_bucket/dst_object*"],
            }
        }
        self._create_job_helper(
            configuration, "google.cloud.bigquery.client.Client.extract_table",
        )

    def test_create_job_query_config(self):
        configuration = {
            "query": {"query": "query", "destinationTable": {"tableId": "table_id"}}
@@ -4217,6 +4232,140 @@ def test_extract_table_w_destination_uris(self):
        self.assertEqual(job.source, source)
        self.assertEqual(list(job.destination_uris), [DESTINATION1, DESTINATION2])

    def test_extract_table_for_source_type_model(self):
        from google.cloud.bigquery.job import ExtractJob

        JOB = "job_id"
        SOURCE = "source_model"
        DESTINATION = "gs://bucket_name/object_name"
        RESOURCE = {
            "jobReference": {"projectId": self.PROJECT, "jobId": JOB},
            "configuration": {
                "extract": {
                    "sourceModel": {
                        "projectId": self.PROJECT,
                        "datasetId": self.DS_ID,
                        "modelId": SOURCE,
                    },
                    "destinationUris": [DESTINATION],
                }
            },
        }
        creds = _make_credentials()
        http = object()
        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
        conn = client._connection = make_connection(RESOURCE)
        dataset = DatasetReference(self.PROJECT, self.DS_ID)
        source = dataset.model(SOURCE)

        job = client.extract_table(
            source, DESTINATION, job_id=JOB, timeout=7.5, source_type="Model"
        )

        # Check that extract_table actually starts the job.
        conn.api_request.assert_called_once_with(
            method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
        )

        # Check the job resource.
        self.assertIsInstance(job, ExtractJob)
        self.assertIs(job._client, client)
        self.assertEqual(job.job_id, JOB)
        self.assertEqual(job.source, source)
        self.assertEqual(list(job.destination_uris), [DESTINATION])

    def test_extract_table_for_source_type_model_w_string_model_id(self):
        JOB = "job_id"
        source_id = "source_model"
        DESTINATION = "gs://bucket_name/object_name"
        RESOURCE = {
            "jobReference": {"projectId": self.PROJECT, "jobId": JOB},
            "configuration": {
                "extract": {
                    "sourceModel": {
                        "projectId": self.PROJECT,
                        "datasetId": self.DS_ID,
                        "modelId": source_id,
                    },
                    "destinationUris": [DESTINATION],
                }
            },
        }
        creds = _make_credentials()
        http = object()
        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
        conn = client._connection = make_connection(RESOURCE)

        client.extract_table(
            # Test with string for model ID.
            "{}.{}".format(self.DS_ID, source_id),
            DESTINATION,
            job_id=JOB,
            timeout=7.5,
            source_type="Model",
        )

        # Check that extract_table actually starts the job.
        conn.api_request.assert_called_once_with(
            method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
        )

    def test_extract_table_for_source_type_model_w_model_object(self):
        from google.cloud.bigquery.model import Model

        JOB = "job_id"
        DESTINATION = "gs://bucket_name/object_name"
        model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID)
        model = Model(model_id)
        RESOURCE = {
            "jobReference": {"projectId": self.PROJECT, "jobId": JOB},
            "configuration": {
                "extract": {
                    "sourceModel": {
                        "projectId": self.PROJECT,
                        "datasetId": self.DS_ID,
                        "modelId": self.MODEL_ID,
                    },
                    "destinationUris": [DESTINATION],
                }
            },
        }
        creds = _make_credentials()
        http = object()
        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
        conn = client._connection = make_connection(RESOURCE)

        client.extract_table(
            # Test with Model class object.
            model,
            DESTINATION,
            job_id=JOB,
            timeout=7.5,
            source_type="Model",
        )

        # Check that extract_table actually starts the job.
        conn.api_request.assert_called_once_with(
            method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
        )

    def test_extract_table_for_invalid_source_type_model(self):
        JOB = "job_id"
        SOURCE = "source_model"
        DESTINATION = "gs://bucket_name/object_name"
        creds = _make_credentials()
        http = object()
        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
        dataset = DatasetReference(self.PROJECT, self.DS_ID)
        source = dataset.model(SOURCE)

        with self.assertRaises(ValueError) as exc:
            client.extract_table(
                source, DESTINATION, job_id=JOB, timeout=7.5, source_type="foo"
            )

        self.assertIn("Cannot pass", exc.exception.args[0])

    def test_query_defaults(self):
        from google.cloud.bigquery.job import QueryJob

36 changes: 32 additions & 4 deletions tests/unit/test_job.py
@@ -3176,10 +3176,16 @@ def _verifyResourceProperties(self, job, resource):

        self.assertEqual(job.destination_uris, config["destinationUris"])

        if "sourceTable" in config:
            table_ref = config["sourceTable"]
            self.assertEqual(job.source.project, table_ref["projectId"])
            self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
            self.assertEqual(job.source.table_id, table_ref["tableId"])
        else:
            model_ref = config["sourceModel"]
            self.assertEqual(job.source.project, model_ref["projectId"])
            self.assertEqual(job.source.dataset_id, model_ref["datasetId"])
            self.assertEqual(job.source.model_id, model_ref["modelId"])

if "compression" in config:
self.assertEqual(job.compression, config["compression"])
@@ -3281,6 +3287,28 @@ def test_from_api_repr_bare(self):
        self.assertIs(job._client, client)
        self._verifyResourceProperties(job, RESOURCE)

    def test_from_api_repr_for_model(self):
        self._setUpConstants()
        client = _make_client(project=self.PROJECT)
        RESOURCE = {
            "id": self.JOB_ID,
            "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
            "configuration": {
                "extract": {
                    "sourceModel": {
                        "projectId": self.PROJECT,
                        "datasetId": self.DS_ID,
                        "modelId": "model_id",
                    },
                    "destinationUris": [self.DESTINATION_URI],
                }
            },
        }
        klass = self._get_target_class()
        job = klass.from_api_repr(RESOURCE, client=client)
        self.assertIs(job._client, client)
        self._verifyResourceProperties(job, RESOURCE)

    def test_from_api_repr_w_properties(self):
        from google.cloud.bigquery.job import Compression

