41
41
except ImportError : # pragma: NO COVER
42
42
from collections import Mapping # type: ignore
43
43
import abc
44
+ import base64
44
45
import json
45
46
import os
46
47
from typing import NamedTuple
@@ -145,9 +146,88 @@ def get_subject_token(self, context, request):
145
146
class _X509Supplier (SubjectTokenSupplier ):
146
147
"""Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token."""
147
148
149
+ def __init__ (self , trust_chain_path , leaf_cert_callback ):
150
+ self ._trust_chain_path = trust_chain_path
151
+ self ._leaf_cert_callback = leaf_cert_callback
152
+
148
153
@_helpers .copy_docstring (SubjectTokenSupplier )
149
154
def get_subject_token (self , context , request ):
150
- return ""
155
+ # Import OpennSSL inline because it is an extra import only required by customers
156
+ # using mTLS.
157
+ from OpenSSL import crypto
158
+
159
+ leaf_cert = crypto .load_certificate (
160
+ crypto .FILETYPE_PEM , self ._leaf_cert_callback ()
161
+ )
162
+ trust_chain = self ._read_trust_chain ()
163
+ cert_chain = []
164
+
165
+ cert_chain .append (_X509Supplier ._encode_cert (leaf_cert ))
166
+
167
+ if trust_chain is None or len (trust_chain ) == 0 :
168
+ return json .dumps (cert_chain )
169
+
170
+ # Append the first cert if it is not the leaf cert.
171
+ first_cert = _X509Supplier ._encode_cert (trust_chain [0 ])
172
+ if first_cert != cert_chain [0 ]:
173
+ cert_chain .append (first_cert )
174
+
175
+ for i in range (1 , len (trust_chain )):
176
+ encoded = _X509Supplier ._encode_cert (trust_chain [i ])
177
+ # Check if the current cert is the leaf cert and raise an exception if it is.
178
+ if encoded == cert_chain [0 ]:
179
+ raise exceptions .RefreshError (
180
+ "The leaf certificate must be at the top of the trust chain file"
181
+ )
182
+ else :
183
+ cert_chain .append (encoded )
184
+ return json .dumps (cert_chain )
185
+
186
+ def _read_trust_chain (self ):
187
+ # Import OpennSSL inline because it is an extra import only required by customers
188
+ # using mTLS.
189
+ from OpenSSL import crypto
190
+
191
+ certificate_trust_chain = []
192
+ # If no trust chain path was provided, return an empty list.
193
+ if self ._trust_chain_path is None or self ._trust_chain_path == "" :
194
+ return certificate_trust_chain
195
+ try :
196
+ # Open the trust chain file.
197
+ with open (self ._trust_chain_path , "rb" ) as f :
198
+ trust_chain_data = f .read ()
199
+ # Split PEM data into individual certificates.
200
+ cert_blocks = trust_chain_data .split (b"-----BEGIN CERTIFICATE-----" )
201
+ for cert_block in cert_blocks :
202
+ # Skip empty blocks.
203
+ if cert_block .strip ():
204
+ cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block
205
+ try :
206
+ # Load each certificate and add it to the trust chain.
207
+ cert = crypto .load_certificate (
208
+ crypto .FILETYPE_PEM , cert_data
209
+ )
210
+ certificate_trust_chain .append (cert )
211
+ except Exception as e :
212
+ raise exceptions .RefreshError (
213
+ "Error loading PEM certificates from the trust chain file '{}'" .format (
214
+ self ._trust_chain_path
215
+ )
216
+ ) from e
217
+ return certificate_trust_chain
218
+ except FileNotFoundError :
219
+ raise exceptions .RefreshError (
220
+ "Trust chain file '{}' was not found." .format (self ._trust_chain_path )
221
+ )
222
+
223
+ def _encode_cert (cert ):
224
+ # Import OpennSSL inline because it is an extra import only required by customers
225
+ # using mTLS.
226
+ from OpenSSL import crypto
227
+
228
+ return base64 .b64encode (
229
+ crypto .dump_certificate (crypto .FILETYPE_ASN1 , cert )
230
+ ).decode ("utf-8" )
151
231
152
232
153
233
def _parse_token_data (token_content , format_type = "text" , subject_token_field_name = None ):
@@ -296,7 +376,9 @@ def __init__(
296
376
self ._credential_source_headers ,
297
377
)
298
378
else : # self._credential_source_certificate
299
- self ._subject_token_supplier = _X509Supplier ()
379
+ self ._subject_token_supplier = _X509Supplier (
380
+ self ._trust_chain_path , self ._get_cert_bytes
381
+ )
300
382
301
383
@_helpers .copy_docstring (external_account .Credentials )
302
384
def retrieve_subject_token (self , request ):
@@ -314,6 +396,10 @@ def _get_mtls_cert_and_key_paths(self):
314
396
self ._certificate_config_location
315
397
)
316
398
399
+ def _get_cert_bytes (self ):
400
+ cert_path , _ = self ._get_mtls_cert_and_key_paths ()
401
+ return _mtls_helper ._read_cert_file (cert_path )
402
+
317
403
def _mtls_required (self ):
318
404
return self ._credential_source_certificate is not None
319
405
@@ -350,6 +436,9 @@ def _validate_certificate_config(self):
350
436
use_default = self ._credential_source_certificate .get (
351
437
"use_default_certificate_config"
352
438
)
439
+ self ._trust_chain_path = self ._credential_source_certificate .get (
440
+ "trust_chain_path"
441
+ )
353
442
if self ._certificate_config_location and use_default :
354
443
raise exceptions .MalformedError (
355
444
"Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true."
0 commit comments