Skip to content

Commit 1823cad

Browse files
feat: add mtls support to client (#492)
* feat: add mtls feature
1 parent 3138d41 commit 1823cad

File tree

6 files changed

+79
-12
lines changed

6 files changed

+79
-12
lines changed

google/cloud/bigquery/_http.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,42 @@
1414

1515
"""Create / interact with Google BigQuery connections."""
1616

17+
import os
18+
import pkg_resources
19+
1720
from google.cloud import _http
1821

1922
from google.cloud.bigquery import __version__
2023

2124

25+
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
26+
# and remove this logic. See:
27+
# https://github.com/googleapis/python-bigquery/issues/509
28+
if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER
29+
release = pkg_resources.get_distribution("google-cloud-core").parsed_version
30+
if release < pkg_resources.parse_version("1.6.0"):
31+
raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature")
32+
33+
2234
class Connection(_http.JSONConnection):
2335
"""A connection to Google BigQuery via the JSON REST API.
2436
2537
Args:
2638
client (google.cloud.bigquery.client.Client): The client that owns the current connection.
2739
2840
client_info (Optional[google.api_core.client_info.ClientInfo]): Instance used to generate user agent.
41+
42+
api_endpoint (str): The api_endpoint to use. If None, the library will decide what endpoint to use.
2943
"""
3044

3145
DEFAULT_API_ENDPOINT = "https://bigquery.googleapis.com"
46+
DEFAULT_API_MTLS_ENDPOINT = "https://bigquery.mtls.googleapis.com"
3247

33-
def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT):
48+
def __init__(self, client, client_info=None, api_endpoint=None):
3449
super(Connection, self).__init__(client, client_info)
35-
self.API_BASE_URL = api_endpoint
50+
self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT
51+
self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT
52+
self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None
3653
self._client_info.gapic_version = __version__
3754
self._client_info.client_library_version = __version__
3855

google/cloud/bigquery/client.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,7 @@
7878
_DEFAULT_CHUNKSIZE = 1048576 # 1024 * 1024 B = 1 MB
7979
_MAX_MULTIPART_SIZE = 5 * 1024 * 1024
8080
_DEFAULT_NUM_RETRIES = 6
81-
_BASE_UPLOAD_TEMPLATE = (
82-
"https://bigquery.googleapis.com/upload/bigquery/v2/projects/"
83-
"{project}/jobs?uploadType="
84-
)
81+
_BASE_UPLOAD_TEMPLATE = "{host}/upload/bigquery/v2/projects/{project}/jobs?uploadType="
8582
_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart"
8683
_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable"
8784
_GENERIC_CONTENT_TYPE = "*/*"
@@ -2547,7 +2544,15 @@ def _initiate_resumable_upload(
25472544

25482545
if project is None:
25492546
project = self.project
2550-
upload_url = _RESUMABLE_URL_TEMPLATE.format(project=project)
2547+
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
2548+
# and remove this logic. See:
2549+
# https://github.com/googleapis/python-bigquery/issues/509
2550+
hostname = (
2551+
self._connection.API_BASE_URL
2552+
if not hasattr(self._connection, "get_api_base_url_for_mtls")
2553+
else self._connection.get_api_base_url_for_mtls()
2554+
)
2555+
upload_url = _RESUMABLE_URL_TEMPLATE.format(host=hostname, project=project)
25512556

25522557
# TODO: modify ResumableUpload to take a retry.Retry object
25532558
# that it can use for the initial RPC.
@@ -2616,7 +2621,15 @@ def _do_multipart_upload(
26162621
if project is None:
26172622
project = self.project
26182623

2619-
upload_url = _MULTIPART_URL_TEMPLATE.format(project=project)
2624+
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
2625+
# and remove this logic. See:
2626+
# https://github.com/googleapis/python-bigquery/issues/509
2627+
hostname = (
2628+
self._connection.API_BASE_URL
2629+
if not hasattr(self._connection, "get_api_base_url_for_mtls")
2630+
else self._connection.get_api_base_url_for_mtls()
2631+
)
2632+
upload_url = _MULTIPART_URL_TEMPLATE.format(host=hostname, project=project)
26202633
upload = MultipartUpload(upload_url, headers=headers)
26212634

26222635
if num_retries is not None:

tests/system/test_client.py

+6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import uuid
2929

3030
import psutil
31+
import pytest
3132
import pytz
3233
import pkg_resources
3334

@@ -132,6 +133,8 @@
132133
else:
133134
PYARROW_INSTALLED_VERSION = None
134135

136+
MTLS_TESTING = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true"
137+
135138

136139
def _has_rows(result):
137140
return len(result) > 0
@@ -2651,6 +2654,9 @@ def test_insert_rows_nested_nested_dictionary(self):
26512654
expected_rows = [("Some value", record)]
26522655
self.assertEqual(row_tuples, expected_rows)
26532656

2657+
@pytest.mark.skipif(
2658+
MTLS_TESTING, reason="mTLS testing has no permission to the max-value.js file"
2659+
)
26542660
def test_create_routine(self):
26552661
routine_name = "test_routine"
26562662
dataset = self.temp_dataset(_make_dataset_id("create_routine"))

tests/unit/helpers.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def make_connection(*responses):
2121
mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection)
2222
mock_conn.user_agent = "testing 1.2.3"
2323
mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")]
24+
mock_conn.API_BASE_URL = "https://bigquery.googleapis.com"
25+
mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL)
2426
return mock_conn
2527

2628

tests/unit/test__http.py

+14
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def _get_target_class():
3232
return Connection
3333

3434
def _make_one(self, *args, **kw):
35+
if "api_endpoint" not in kw:
36+
kw["api_endpoint"] = "https://bigquery.googleapis.com"
37+
3538
return self._get_target_class()(*args, **kw)
3639

3740
def test_build_api_url_no_extra_query_params(self):
@@ -138,3 +141,14 @@ def test_extra_headers_replace(self):
138141
url=expected_uri,
139142
timeout=self._get_default_timeout(),
140143
)
144+
145+
def test_ctor_mtls(self):
146+
conn = self._make_one(object(), api_endpoint=None)
147+
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True)
148+
self.assertEqual(conn.API_BASE_URL, "https://bigquery.googleapis.com")
149+
self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com")
150+
151+
conn = self._make_one(object(), api_endpoint="http://foo")
152+
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False)
153+
self.assertEqual(conn.API_BASE_URL, "http://foo")
154+
self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com")

tests/unit/test_client.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -2057,6 +2057,7 @@ def test_get_table_sets_user_agent(self):
20572057
url=mock.ANY, method=mock.ANY, headers=mock.ANY, data=mock.ANY
20582058
)
20592059
http.reset_mock()
2060+
http.is_mtls = False
20602061
mock_response.status_code = 200
20612062
mock_response.json.return_value = self._make_table_resource()
20622063
user_agent_override = client_info.ClientInfo(user_agent="my-application/1.2.3")
@@ -4425,7 +4426,7 @@ def _mock_transport(self, status_code, headers, content=b""):
44254426
fake_transport.request.return_value = fake_response
44264427
return fake_transport
44274428

4428-
def _initiate_resumable_upload_helper(self, num_retries=None):
4429+
def _initiate_resumable_upload_helper(self, num_retries=None, mtls=False):
44294430
from google.resumable_media.requests import ResumableUpload
44304431
from google.cloud.bigquery.client import _DEFAULT_CHUNKSIZE
44314432
from google.cloud.bigquery.client import _GENERIC_CONTENT_TYPE
@@ -4440,6 +4441,8 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
44404441
fake_transport = self._mock_transport(http.client.OK, response_headers)
44414442
client = self._make_one(project=self.PROJECT, _http=fake_transport)
44424443
conn = client._connection = make_connection()
4444+
if mtls:
4445+
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")
44434446

44444447
# Create some mock arguments and call the method under test.
44454448
data = b"goodbye gudbi gootbee"
@@ -4454,8 +4457,10 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
44544457

44554458
# Check the returned values.
44564459
self.assertIsInstance(upload, ResumableUpload)
4460+
4461+
host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
44574462
upload_url = (
4458-
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{self.PROJECT}"
4463+
f"{host_name}/upload/bigquery/v2/projects/{self.PROJECT}"
44594464
"/jobs?uploadType=resumable"
44604465
)
44614466
self.assertEqual(upload.upload_url, upload_url)
@@ -4494,11 +4499,14 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
44944499
def test__initiate_resumable_upload(self):
44954500
self._initiate_resumable_upload_helper()
44964501

4502+
def test__initiate_resumable_upload_mtls(self):
4503+
self._initiate_resumable_upload_helper(mtls=True)
4504+
44974505
def test__initiate_resumable_upload_with_retry(self):
44984506
self._initiate_resumable_upload_helper(num_retries=11)
44994507

45004508
def _do_multipart_upload_success_helper(
4501-
self, get_boundary, num_retries=None, project=None
4509+
self, get_boundary, num_retries=None, project=None, mtls=False
45024510
):
45034511
from google.cloud.bigquery.client import _get_upload_headers
45044512
from google.cloud.bigquery.job import LoadJob
@@ -4508,6 +4516,8 @@ def _do_multipart_upload_success_helper(
45084516
fake_transport = self._mock_transport(http.client.OK, {})
45094517
client = self._make_one(project=self.PROJECT, _http=fake_transport)
45104518
conn = client._connection = make_connection()
4519+
if mtls:
4520+
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")
45114521

45124522
if project is None:
45134523
project = self.PROJECT
@@ -4530,8 +4540,9 @@ def _do_multipart_upload_success_helper(
45304540
self.assertEqual(stream.tell(), size)
45314541
get_boundary.assert_called_once_with()
45324542

4543+
host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
45334544
upload_url = (
4534-
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{project}"
4545+
f"{host_name}/upload/bigquery/v2/projects/{project}"
45354546
"/jobs?uploadType=multipart"
45364547
)
45374548
payload = (
@@ -4556,6 +4567,10 @@ def _do_multipart_upload_success_helper(
45564567
def test__do_multipart_upload(self, get_boundary):
45574568
self._do_multipart_upload_success_helper(get_boundary)
45584569

4570+
@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
4571+
def test__do_multipart_upload_mtls(self, get_boundary):
4572+
self._do_multipart_upload_success_helper(get_boundary, mtls=True)
4573+
45594574
@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
45604575
def test__do_multipart_upload_with_retry(self, get_boundary):
45614576
self._do_multipart_upload_success_helper(get_boundary, num_retries=8)

0 commit comments

Comments
 (0)