Skip to content

Commit da0ee3e

Browse files
fix: expose ssl credentials from transport (#677)
Expose ssl credentials from transport. This is used to fix pubsub client [mtls issue](googleapis/python-pubsub#224). Pubsub client creates its own transport so mtls is completely missing. The solution would be taking the ssl credentials from the auto-generated client's transport and passing it when the handwritten client creates the transport.
1 parent 0fe9330 commit da0ee3e

File tree

5 files changed

+17
-0
lines changed

5 files changed

+17
-0
lines changed

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
8888
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
8989
creation failed for any reason.
9090
"""
91+
self._ssl_channel_credentials = ssl_channel_credentials
92+
9193
if channel:
9294
# Sanity check: Ensure that channel and credentials are not both
9395
# provided.
9496
credentials = False
9597

9698
# If a channel was explicitly provided, set it.
9799
self._grpc_channel = channel
100+
self._ssl_channel_credentials = None
98101
elif api_mtls_endpoint:
99102
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
100103

@@ -122,6 +125,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
122125
scopes=scopes or self.AUTH_SCOPES,
123126
quota_project_id=quota_project_id,
124127
)
128+
self._ssl_channel_credentials = ssl_credentials
125129
else:
126130
host = host if ":" in host else host + ":443"
127131

gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel():
708708
)
709709
assert transport.grpc_channel == channel
710710
assert transport._host == "squid.clam.whelk:443"
711+
assert transport._ssl_channel_credentials == None
711712

712713

713714
@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}])
@@ -749,6 +750,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s
749750
quota_project_id=None,
750751
)
751752
assert transport.grpc_channel == mock_grpc_channel
753+
assert transport._ssl_channel_credentials == mock_ssl_cred
752754

753755

754756
@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }},])

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
9696
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
9797
and ``credentials_file`` are passed.
9898
"""
99+
self._ssl_channel_credentials = ssl_channel_credentials
100+
99101
if channel:
100102
# Sanity check: Ensure that channel and credentials are not both
101103
# provided.
102104
credentials = False
103105

104106
# If a channel was explicitly provided, set it.
105107
self._grpc_channel = channel
108+
self._ssl_channel_credentials = None
106109
elif api_mtls_endpoint:
107110
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
108111

@@ -130,6 +133,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
130133
scopes=scopes or self.AUTH_SCOPES,
131134
quota_project_id=quota_project_id,
132135
)
136+
self._ssl_channel_credentials = ssl_credentials
133137
else:
134138
host = host if ":" in host else host + ":443"
135139

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,16 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
140140
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
141141
and ``credentials_file`` are passed.
142142
"""
143+
self._ssl_channel_credentials = ssl_channel_credentials
144+
143145
if channel:
144146
# Sanity check: Ensure that channel and credentials are not both
145147
# provided.
146148
credentials = False
147149

148150
# If a channel was explicitly provided, set it.
149151
self._grpc_channel = channel
152+
self._ssl_channel_credentials = None
150153
elif api_mtls_endpoint:
151154
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
152155

@@ -174,6 +177,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
174177
scopes=scopes or self.AUTH_SCOPES,
175178
quota_project_id=quota_project_id,
176179
)
180+
self._ssl_channel_credentials = ssl_credentials
177181
else:
178182
host = host if ":" in host else host + ":443"
179183

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel():
11841184
)
11851185
assert transport.grpc_channel == channel
11861186
assert transport._host == "squid.clam.whelk:443"
1187+
assert transport._ssl_channel_credentials == None
11871188

11881189

11891190
def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel():
@@ -1196,6 +1197,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel():
11961197
)
11971198
assert transport.grpc_channel == channel
11981199
assert transport._host == "squid.clam.whelk:443"
1200+
assert transport._ssl_channel_credentials == None
11991201

12001202

12011203
@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}])
@@ -1237,6 +1239,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s
12371239
quota_project_id=None,
12381240
)
12391241
assert transport.grpc_channel == mock_grpc_channel
1242+
assert transport._ssl_channel_credentials == mock_ssl_cred
12401243

12411244

12421245
@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}])

0 commit comments

Comments
 (0)