Skip to content

Commit 4a7a514

Browse files
feat(bigquery): add support of model for extract job (#71)
* feat(bigquery): add support of model for extract job
* feat(bigquery): nit
* feat(bigquery): add source model for create job method
* feat(bigquery): nits
* feat(bigquery): nit
1 parent df29b7d commit 4a7a514

File tree

5 files changed

+252
-20
lines changed

5 files changed

+252
-20
lines changed

google/cloud/bigquery/client.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from google.cloud.bigquery import job
6666
from google.cloud.bigquery.model import Model
6767
from google.cloud.bigquery.model import ModelReference
68+
from google.cloud.bigquery.model import _model_arg_to_model_ref
6869
from google.cloud.bigquery.query import _QueryResults
6970
from google.cloud.bigquery.retry import DEFAULT_RETRY
7071
from google.cloud.bigquery.routine import Routine
@@ -1364,9 +1365,17 @@ def create_job(self, job_config, retry=DEFAULT_RETRY):
13641365
job_config
13651366
)
13661367
source = _get_sub_prop(job_config, ["extract", "sourceTable"])
1368+
source_type = "Table"
1369+
if not source:
1370+
source = _get_sub_prop(job_config, ["extract", "sourceModel"])
1371+
source_type = "Model"
13671372
destination_uris = _get_sub_prop(job_config, ["extract", "destinationUris"])
13681373
return self.extract_table(
1369-
source, destination_uris, job_config=extract_job_config, retry=retry
1374+
source,
1375+
destination_uris,
1376+
job_config=extract_job_config,
1377+
retry=retry,
1378+
source_type=source_type,
13701379
)
13711380
elif "query" in job_config:
13721381
copy_config = copy.deepcopy(job_config)
@@ -2282,6 +2291,7 @@ def extract_table(
22822291
job_config=None,
22832292
retry=DEFAULT_RETRY,
22842293
timeout=None,
2294+
source_type="Table",
22852295
):
22862296
"""Start a job to extract a table into Cloud Storage files.
22872297
@@ -2292,9 +2302,11 @@ def extract_table(
22922302
source (Union[ \
22932303
google.cloud.bigquery.table.Table, \
22942304
google.cloud.bigquery.table.TableReference, \
2305+
google.cloud.bigquery.model.Model, \
2306+
google.cloud.bigquery.model.ModelReference, \
22952307
src, \
22962308
]):
2297-
Table to be extracted.
2309+
Table or Model to be extracted.
22982310
destination_uris (Union[str, Sequence[str]]):
22992311
URIs of Cloud Storage file(s) into which table data is to be
23002312
extracted; in format
@@ -2319,17 +2331,19 @@ def extract_table(
23192331
timeout (Optional[float]):
23202332
The number of seconds to wait for the underlying HTTP transport
23212333
before using ``retry``.
2322-
Args:
2323-
source (google.cloud.bigquery.table.TableReference): table to be extracted.
2324-
2334+
source_type (str):
2335+
(Optional) Type of source to be extracted. ``Table`` or ``Model``.
2336+
Defaults to ``Table``.
23252337
Returns:
23262338
google.cloud.bigquery.job.ExtractJob: A new extract job instance.
23272339
23282340
Raises:
23292341
TypeError:
23302342
If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.ExtractJobConfig`
23312343
class.
2332-
"""
2344+
ValueError:
2345+
If ``source_type`` is not one of ``Table`` or ``Model``.
2346+
"""
23332347
job_id = _make_job_id(job_id, job_id_prefix)
23342348

23352349
if project is None:
@@ -2339,7 +2353,17 @@ def extract_table(
23392353
location = self.location
23402354

23412355
job_ref = job._JobReference(job_id, project=project, location=location)
2342-
source = _table_arg_to_table_ref(source, default_project=self.project)
2356+
src = source_type.lower()
2357+
if src == "table":
2358+
source = _table_arg_to_table_ref(source, default_project=self.project)
2359+
elif src == "model":
2360+
source = _model_arg_to_model_ref(source, default_project=self.project)
2361+
else:
2362+
raise ValueError(
2363+
"Cannot pass `{}` as a ``source_type``, pass Table or Model".format(
2364+
source_type
2365+
)
2366+
)
23432367

23442368
if isinstance(destination_uris, six.string_types):
23452369
destination_uris = [destination_uris]

google/cloud/bigquery/job.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1990,8 +1990,11 @@ class ExtractJob(_AsyncJob):
19901990
Args:
19911991
job_id (str): the job's ID.
19921992
1993-
source (google.cloud.bigquery.table.TableReference):
1994-
Table into which data is to be loaded.
1993+
source (Union[ \
1994+
google.cloud.bigquery.table.TableReference, \
1995+
google.cloud.bigquery.model.ModelReference \
1996+
]):
1997+
Table or Model from which data is to be extracted.
19951998
19961999
destination_uris (List[str]):
19972000
URIs describing where the extracted data will be written in Cloud
@@ -2067,14 +2070,20 @@ def destination_uri_file_counts(self):
20672070
def to_api_repr(self):
20682071
"""Generate a resource for :meth:`_begin`."""
20692072

2073+
configuration = self._configuration.to_api_repr()
20702074
source_ref = {
20712075
"projectId": self.source.project,
20722076
"datasetId": self.source.dataset_id,
2073-
"tableId": self.source.table_id,
20742077
}
20752078

2076-
configuration = self._configuration.to_api_repr()
2077-
_helpers._set_sub_prop(configuration, ["extract", "sourceTable"], source_ref)
2079+
source = "sourceTable"
2080+
if isinstance(self.source, TableReference):
2081+
source_ref["tableId"] = self.source.table_id
2082+
else:
2083+
source_ref["modelId"] = self.source.model_id
2084+
source = "sourceModel"
2085+
2086+
_helpers._set_sub_prop(configuration, ["extract", source], source_ref)
20782087
_helpers._set_sub_prop(
20792088
configuration, ["extract", "destinationUris"], self.destination_uris
20802089
)
@@ -2112,10 +2121,20 @@ def from_api_repr(cls, resource, client):
21122121
source_config = _helpers._get_sub_prop(
21132122
config_resource, ["extract", "sourceTable"]
21142123
)
2115-
dataset = DatasetReference(
2116-
source_config["projectId"], source_config["datasetId"]
2117-
)
2118-
source = dataset.table(source_config["tableId"])
2124+
if source_config:
2125+
dataset = DatasetReference(
2126+
source_config["projectId"], source_config["datasetId"]
2127+
)
2128+
source = dataset.table(source_config["tableId"])
2129+
else:
2130+
source_config = _helpers._get_sub_prop(
2131+
config_resource, ["extract", "sourceModel"]
2132+
)
2133+
dataset = DatasetReference(
2134+
source_config["projectId"], source_config["datasetId"]
2135+
)
2136+
source = dataset.model(source_config["modelId"])
2137+
21192138
destination_uris = _helpers._get_sub_prop(
21202139
config_resource, ["extract", "destinationUris"]
21212140
)

google/cloud/bigquery/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,15 @@ def __repr__(self):
433433
return "ModelReference(project_id='{}', dataset_id='{}', model_id='{}')".format(
434434
self.project, self.dataset_id, self.model_id
435435
)
436+
437+
438+
def _model_arg_to_model_ref(value, default_project=None):
439+
"""Helper to convert a string or Model to ModelReference.
440+
441+
This function keeps ModelReference and other kinds of objects unchanged.
442+
"""
443+
if isinstance(value, six.string_types):
444+
return ModelReference.from_string(value, default_project=default_project)
445+
if isinstance(value, Model):
446+
return value.reference
447+
return value

tests/unit/test_client.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2884,6 +2884,21 @@ def test_create_job_extract_config(self):
28842884
configuration, "google.cloud.bigquery.client.Client.extract_table",
28852885
)
28862886

2887+
def test_create_job_extract_config_for_model(self):
2888+
configuration = {
2889+
"extract": {
2890+
"sourceModel": {
2891+
"projectId": self.PROJECT,
2892+
"datasetId": self.DS_ID,
2893+
"modelId": "source_model",
2894+
},
2895+
"destinationUris": ["gs://test_bucket/dst_object*"],
2896+
}
2897+
}
2898+
self._create_job_helper(
2899+
configuration, "google.cloud.bigquery.client.Client.extract_table",
2900+
)
2901+
28872902
def test_create_job_query_config(self):
28882903
configuration = {
28892904
"query": {"query": "query", "destinationTable": {"tableId": "table_id"}}
@@ -4217,6 +4232,140 @@ def test_extract_table_w_destination_uris(self):
42174232
self.assertEqual(job.source, source)
42184233
self.assertEqual(list(job.destination_uris), [DESTINATION1, DESTINATION2])
42194234

4235+
def test_extract_table_for_source_type_model(self):
4236+
from google.cloud.bigquery.job import ExtractJob
4237+
4238+
JOB = "job_id"
4239+
SOURCE = "source_model"
4240+
DESTINATION = "gs://bucket_name/object_name"
4241+
RESOURCE = {
4242+
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
4243+
"configuration": {
4244+
"extract": {
4245+
"sourceModel": {
4246+
"projectId": self.PROJECT,
4247+
"datasetId": self.DS_ID,
4248+
"modelId": SOURCE,
4249+
},
4250+
"destinationUris": [DESTINATION],
4251+
}
4252+
},
4253+
}
4254+
creds = _make_credentials()
4255+
http = object()
4256+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
4257+
conn = client._connection = make_connection(RESOURCE)
4258+
dataset = DatasetReference(self.PROJECT, self.DS_ID)
4259+
source = dataset.model(SOURCE)
4260+
4261+
job = client.extract_table(
4262+
source, DESTINATION, job_id=JOB, timeout=7.5, source_type="Model"
4263+
)
4264+
4265+
# Check that extract_table actually starts the job.
4266+
conn.api_request.assert_called_once_with(
4267+
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
4268+
)
4269+
4270+
# Check the job resource.
4271+
self.assertIsInstance(job, ExtractJob)
4272+
self.assertIs(job._client, client)
4273+
self.assertEqual(job.job_id, JOB)
4274+
self.assertEqual(job.source, source)
4275+
self.assertEqual(list(job.destination_uris), [DESTINATION])
4276+
4277+
def test_extract_table_for_source_type_model_w_string_model_id(self):
4278+
JOB = "job_id"
4279+
source_id = "source_model"
4280+
DESTINATION = "gs://bucket_name/object_name"
4281+
RESOURCE = {
4282+
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
4283+
"configuration": {
4284+
"extract": {
4285+
"sourceModel": {
4286+
"projectId": self.PROJECT,
4287+
"datasetId": self.DS_ID,
4288+
"modelId": source_id,
4289+
},
4290+
"destinationUris": [DESTINATION],
4291+
}
4292+
},
4293+
}
4294+
creds = _make_credentials()
4295+
http = object()
4296+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
4297+
conn = client._connection = make_connection(RESOURCE)
4298+
4299+
client.extract_table(
4300+
# Test with string for model ID.
4301+
"{}.{}".format(self.DS_ID, source_id),
4302+
DESTINATION,
4303+
job_id=JOB,
4304+
timeout=7.5,
4305+
source_type="Model",
4306+
)
4307+
4308+
# Check that extract_table actually starts the job.
4309+
conn.api_request.assert_called_once_with(
4310+
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
4311+
)
4312+
4313+
def test_extract_table_for_source_type_model_w_model_object(self):
4314+
from google.cloud.bigquery.model import Model
4315+
4316+
JOB = "job_id"
4317+
DESTINATION = "gs://bucket_name/object_name"
4318+
model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID)
4319+
model = Model(model_id)
4320+
RESOURCE = {
4321+
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
4322+
"configuration": {
4323+
"extract": {
4324+
"sourceModel": {
4325+
"projectId": self.PROJECT,
4326+
"datasetId": self.DS_ID,
4327+
"modelId": self.MODEL_ID,
4328+
},
4329+
"destinationUris": [DESTINATION],
4330+
}
4331+
},
4332+
}
4333+
creds = _make_credentials()
4334+
http = object()
4335+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
4336+
conn = client._connection = make_connection(RESOURCE)
4337+
4338+
client.extract_table(
4339+
# Test with Model class object.
4340+
model,
4341+
DESTINATION,
4342+
job_id=JOB,
4343+
timeout=7.5,
4344+
source_type="Model",
4345+
)
4346+
4347+
# Check that extract_table actually starts the job.
4348+
conn.api_request.assert_called_once_with(
4349+
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
4350+
)
4351+
4352+
def test_extract_table_for_invalid_source_type_model(self):
4353+
JOB = "job_id"
4354+
SOURCE = "source_model"
4355+
DESTINATION = "gs://bucket_name/object_name"
4356+
creds = _make_credentials()
4357+
http = object()
4358+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
4359+
dataset = DatasetReference(self.PROJECT, self.DS_ID)
4360+
source = dataset.model(SOURCE)
4361+
4362+
with self.assertRaises(ValueError) as exc:
4363+
client.extract_table(
4364+
source, DESTINATION, job_id=JOB, timeout=7.5, source_type="foo"
4365+
)
4366+
4367+
self.assertIn("Cannot pass", exc.exception.args[0])
4368+
42204369
def test_query_defaults(self):
42214370
from google.cloud.bigquery.job import QueryJob
42224371

tests/unit/test_job.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3176,10 +3176,16 @@ def _verifyResourceProperties(self, job, resource):
31763176

31773177
self.assertEqual(job.destination_uris, config["destinationUris"])
31783178

3179-
table_ref = config["sourceTable"]
3180-
self.assertEqual(job.source.project, table_ref["projectId"])
3181-
self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
3182-
self.assertEqual(job.source.table_id, table_ref["tableId"])
3179+
if "sourceTable" in config:
3180+
table_ref = config["sourceTable"]
3181+
self.assertEqual(job.source.project, table_ref["projectId"])
3182+
self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
3183+
self.assertEqual(job.source.table_id, table_ref["tableId"])
3184+
else:
3185+
model_ref = config["sourceModel"]
3186+
self.assertEqual(job.source.project, model_ref["projectId"])
3187+
self.assertEqual(job.source.dataset_id, model_ref["datasetId"])
3188+
self.assertEqual(job.source.model_id, model_ref["modelId"])
31833189

31843190
if "compression" in config:
31853191
self.assertEqual(job.compression, config["compression"])
@@ -3281,6 +3287,28 @@ def test_from_api_repr_bare(self):
32813287
self.assertIs(job._client, client)
32823288
self._verifyResourceProperties(job, RESOURCE)
32833289

3290+
def test_from_api_repr_for_model(self):
3291+
self._setUpConstants()
3292+
client = _make_client(project=self.PROJECT)
3293+
RESOURCE = {
3294+
"id": self.JOB_ID,
3295+
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
3296+
"configuration": {
3297+
"extract": {
3298+
"sourceModel": {
3299+
"projectId": self.PROJECT,
3300+
"datasetId": self.DS_ID,
3301+
"modelId": "model_id",
3302+
},
3303+
"destinationUris": [self.DESTINATION_URI],
3304+
}
3305+
},
3306+
}
3307+
klass = self._get_target_class()
3308+
job = klass.from_api_repr(RESOURCE, client=client)
3309+
self.assertIs(job._client, client)
3310+
self._verifyResourceProperties(job, RESOURCE)
3311+
32843312
def test_from_api_repr_w_properties(self):
32853313
from google.cloud.bigquery.job import Compression
32863314

0 commit comments

Comments
 (0)