@@ -18,7 +18,7 @@ import textwrap
18
18
19
19
from packaging import version
20
20
from urllib .parse import urlparse
21
- from typing import Optional
21
+ from typing import Iterable , Optional
22
22
23
23
PARENT_FOLDER = os .path .realpath (os .path .join (os .path .dirname (__file__ ), ".." ))
24
24
if os .path .isdir (os .path .join (PARENT_FOLDER , ".venv" )):
@@ -31,14 +31,30 @@ DRY_RUN = str(os.environ.get("DRY_RUN")).strip().lower() in ["1", "true"]
31
31
DEFAULT_REGION = "us-east-1"
32
32
DEFAULT_ACCESS_KEY = "test"
33
33
AWS_ENDPOINT_URL = os .environ .get ("AWS_ENDPOINT_URL" )
34
- CUSTOMIZE_ACCESS_KEY = str (os .environ .get ("CUSTOMIZE_ACCESS_KEY" )).strip ().lower () in ["1" , "true" ]
34
+ CUSTOMIZE_ACCESS_KEY = str (os .environ .get ("CUSTOMIZE_ACCESS_KEY" )).strip ().lower () in [
35
+ "1" ,
36
+ "true" ,
37
+ ]
35
38
LOCALHOST_HOSTNAME = "localhost.localstack.cloud"
36
39
S3_HOSTNAME = os .environ .get ("S3_HOSTNAME" ) or f"s3.{ LOCALHOST_HOSTNAME } "
37
40
USE_EXEC = str (os .environ .get ("USE_EXEC" )).strip ().lower () in ["1" , "true" ]
38
41
TF_CMD = os .environ .get ("TF_CMD" ) or "terraform"
39
- TF_UNPROXIED_CMDS = os .environ .get ("TF_UNPROXIED_CMDS" ).split (sep = "," ) if os .environ .get ("TF_UNPROXIED_CMDS" ) else ("fmt" , "validate" , "version" )
40
- LS_PROVIDERS_FILE = os .environ .get ("LS_PROVIDERS_FILE" ) or "localstack_providers_override.tf"
41
- LOCALSTACK_HOSTNAME = urlparse (AWS_ENDPOINT_URL ).hostname or os .environ .get ("LOCALSTACK_HOSTNAME" ) or "localhost"
42
+ ADDITIONAL_TF_OVERRIDE_LOCATIONS = os .environ .get (
43
+ "ADDITIONAL_TF_OVERRIDE_LOCATIONS" , default = ""
44
+ )
45
+ TF_UNPROXIED_CMDS = (
46
+ os .environ .get ("TF_UNPROXIED_CMDS" ).split (sep = "," )
47
+ if os .environ .get ("TF_UNPROXIED_CMDS" )
48
+ else ("fmt" , "validate" , "version" )
49
+ )
50
+ LS_PROVIDERS_FILE = (
51
+ os .environ .get ("LS_PROVIDERS_FILE" ) or "localstack_providers_override.tf"
52
+ )
53
+ LOCALSTACK_HOSTNAME = (
54
+ urlparse (AWS_ENDPOINT_URL ).hostname
55
+ or os .environ .get ("LOCALSTACK_HOSTNAME" )
56
+ or "localhost"
57
+ )
42
58
EDGE_PORT = int (urlparse (AWS_ENDPOINT_URL ).port or os .environ .get ("EDGE_PORT" ) or 4566 )
43
59
TF_VERSION : Optional [version .Version ] = None
44
60
TF_PROVIDER_CONFIG = """
@@ -133,11 +149,19 @@ SERVICE_REPLACEMENTS = {
133
149
# CONFIG GENERATION UTILS
134
150
# ---
135
151
136
- def create_provider_config_file (provider_aliases = None ):
152
+
153
+ def create_provider_config_file (provider_file_path : str , provider_aliases = None ) -> None :
137
154
provider_aliases = provider_aliases or []
138
155
139
156
# Force service alias replacements
140
- SERVICE_REPLACEMENTS .update ({alias : alias_pairs [0 ] for alias_pairs in SERVICE_ALIASES for alias in alias_pairs if alias != alias_pairs [0 ]})
157
+ SERVICE_REPLACEMENTS .update (
158
+ {
159
+ alias : alias_pairs [0 ]
160
+ for alias_pairs in SERVICE_ALIASES
161
+ for alias in alias_pairs
162
+ if alias != alias_pairs [0 ]
163
+ }
164
+ )
141
165
142
166
# create list of service names
143
167
services = list (config .get_service_ports ())
@@ -162,9 +186,11 @@ def create_provider_config_file(provider_aliases=None):
162
186
for provider in provider_aliases :
163
187
provider_config = TF_PROVIDER_CONFIG .replace (
164
188
"<access_key>" ,
165
- get_access_key (provider ) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
189
+ get_access_key (provider ) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY ,
190
+ )
191
+ endpoints = "\n " .join (
192
+ [f' { s } = "{ get_service_endpoint (s )} "' for s in services ]
166
193
)
167
- endpoints = "\n " .join ([f' { s } = "{ get_service_endpoint (s )} "' for s in services ])
168
194
provider_config = provider_config .replace ("<endpoints>" , endpoints )
169
195
additional_configs = []
170
196
if use_s3_path_style ():
@@ -178,7 +204,9 @@ def create_provider_config_file(provider_aliases=None):
178
204
if isinstance (region , list ):
179
205
region = region [0 ]
180
206
additional_configs += [f'region = "{ region } "' ]
181
- provider_config = provider_config .replace ("<configs>" , "\n " .join (additional_configs ))
207
+ provider_config = provider_config .replace (
208
+ "<configs>" , "\n " .join (additional_configs )
209
+ )
182
210
provider_configs .append (provider_config )
183
211
184
212
# construct final config file content
@@ -188,10 +216,7 @@ def create_provider_config_file(provider_aliases=None):
188
216
tf_config += generate_s3_backend_config ()
189
217
190
218
# write temporary config file
191
- providers_file = get_providers_file_path ()
192
- write_provider_config_file (providers_file , tf_config )
193
-
194
- return providers_file
219
+ write_provider_config_file (provider_file_path , tf_config )
195
220
196
221
197
222
def write_provider_config_file (providers_file , tf_config ):
@@ -200,12 +225,18 @@ def write_provider_config_file(providers_file, tf_config):
200
225
fp .write (tf_config )
201
226
202
227
203
- def get_providers_file_path () -> str :
204
- """Determine the path under which the providers override file should be stored"""
228
+ def get_default_provider_folder_path () -> str :
229
+ """Determine the folder under which the providers override file should be stored"""
205
230
chdir = [arg for arg in sys .argv if arg .startswith ("-chdir=" )]
206
231
base_dir = "."
207
232
if chdir :
208
233
base_dir = chdir [0 ].removeprefix ("-chdir=" )
234
+
235
+ return os .path .abspath (base_dir )
236
+
237
+
238
+ def get_providers_file_path (base_dir ) -> str :
239
+ """Retrieve the path under which the providers override file should be stored"""
209
240
return os .path .join (base_dir , LS_PROVIDERS_FILE )
210
241
211
242
@@ -217,7 +248,11 @@ def determine_provider_aliases() -> list:
217
248
for _file , obj in tf_files .items ():
218
249
try :
219
250
providers = ensure_list (obj .get ("provider" , []))
220
- aws_providers = [prov ["aws" ] for prov in providers if prov .get ("aws" ) and prov .get ("aws" ).get ("alias" ) not in skipped ]
251
+ aws_providers = [
252
+ prov ["aws" ]
253
+ for prov in providers
254
+ if prov .get ("aws" ) and prov .get ("aws" ).get ("alias" ) not in skipped
255
+ ]
221
256
result .extend (aws_providers )
222
257
except Exception as e :
223
258
print (f"Warning: Unable to extract providers from { _file } :" , e )
@@ -258,7 +293,6 @@ def generate_s3_backend_config() -> str:
258
293
"skip_credentials_validation" : True ,
259
294
"skip_metadata_api_check" : True ,
260
295
"secret_key" : "test" ,
261
-
262
296
"endpoints" : {
263
297
"s3" : get_service_endpoint ("s3" ),
264
298
"iam" : get_service_endpoint ("iam" ),
@@ -269,23 +303,37 @@ def generate_s3_backend_config() -> str:
269
303
}
270
304
# Merge in legacy endpoint configs if not existing already
271
305
if is_tf_legacy and backend_config .get ("endpoints" ):
272
- print ("Warning: Unsupported backend option(s) detected (`endpoints`). Please make sure you always use the corresponding options to your Terraform version." )
306
+ print (
307
+ "Warning: Unsupported backend option(s) detected (`endpoints`). Please make sure you always use the corresponding options to your Terraform version."
308
+ )
273
309
exit (1 )
274
310
for legacy_endpoint , endpoint in legacy_endpoint_mappings .items ():
275
- if legacy_endpoint in backend_config and backend_config .get ("endpoints" ) and endpoint in backend_config ["endpoints" ]:
311
+ if (
312
+ legacy_endpoint in backend_config
313
+ and backend_config .get ("endpoints" )
314
+ and endpoint in backend_config ["endpoints" ]
315
+ ):
276
316
del backend_config [legacy_endpoint ]
277
317
continue
278
- if legacy_endpoint in backend_config and (not backend_config .get ("endpoints" ) or endpoint not in backend_config ["endpoints" ]):
318
+ if legacy_endpoint in backend_config and (
319
+ not backend_config .get ("endpoints" )
320
+ or endpoint not in backend_config ["endpoints" ]
321
+ ):
279
322
if not backend_config .get ("endpoints" ):
280
323
backend_config ["endpoints" ] = {}
281
- backend_config ["endpoints" ].update ({endpoint : backend_config [legacy_endpoint ]})
324
+ backend_config ["endpoints" ].update (
325
+ {endpoint : backend_config [legacy_endpoint ]}
326
+ )
282
327
del backend_config [legacy_endpoint ]
283
328
# Add any missing default endpoints
284
329
if backend_config .get ("endpoints" ):
285
330
backend_config ["endpoints" ] = {
286
331
k : backend_config ["endpoints" ].get (k ) or v
287
- for k , v in configs ["endpoints" ].items ()}
288
- backend_config ["access_key" ] = get_access_key (backend_config ) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
332
+ for k , v in configs ["endpoints" ].items ()
333
+ }
334
+ backend_config ["access_key" ] = (
335
+ get_access_key (backend_config ) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
336
+ )
289
337
configs .update (backend_config )
290
338
if not DRY_RUN :
291
339
get_or_create_bucket (configs ["bucket" ])
@@ -298,22 +346,27 @@ def generate_s3_backend_config() -> str:
298
346
elif isinstance (value , dict ):
299
347
if key == "endpoints" and is_tf_legacy :
300
348
for legacy_endpoint , endpoint in legacy_endpoint_mappings .items ():
301
- config_options += f'\n { legacy_endpoint } = "{ configs [key ][endpoint ]} "'
349
+ config_options += (
350
+ f'\n { legacy_endpoint } = "{ configs [key ][endpoint ]} "'
351
+ )
302
352
continue
303
353
else :
304
354
value = textwrap .indent (
305
- text = f"{ key } = {{\n " + "\n " .join ([f' { k } = "{ v } "' for k , v in value .items ()]) + "\n }" ,
306
- prefix = " " * 4 )
355
+ text = f"{ key } = {{\n "
356
+ + "\n " .join ([f' { k } = "{ v } "' for k , v in value .items ()])
357
+ + "\n }" ,
358
+ prefix = " " * 4 ,
359
+ )
307
360
config_options += f"\n { value } "
308
361
continue
309
362
elif isinstance (value , list ):
310
363
# TODO this will break if it's a list of dicts or other complex object
311
364
# this serialization logic should probably be moved to a separate recursive function
312
365
as_string = [f'"{ item } "' for item in value ]
313
- value = f'[ { ", " .join (as_string )} ]'
366
+ value = f"[ { ', ' .join (as_string )} ]"
314
367
else :
315
368
value = f'"{ str (value )} "'
316
- config_options += f' \n { key } = { value } '
369
+ config_options += f" \n { key } = { value } "
317
370
result = result .replace ("<configs>" , config_options )
318
371
return result
319
372
@@ -337,6 +390,7 @@ def check_override_file(providers_file: str) -> None:
337
390
# AWS CLIENT UTILS
338
391
# ---
339
392
393
+
340
394
def use_s3_path_style () -> bool :
341
395
"""
342
396
Whether to use S3 path addressing (depending on the configured S3 endpoint)
@@ -361,6 +415,7 @@ def get_region() -> str:
361
415
# Note that boto3 is currently not included in the dependencies, to
362
416
# keep the library lightweight.
363
417
import boto3
418
+
364
419
region = boto3 .session .Session ().region_name
365
420
except Exception :
366
421
pass
@@ -369,7 +424,9 @@ def get_region() -> str:
369
424
370
425
371
426
def get_access_key (provider : dict ) -> str :
372
- access_key = str (os .environ .get ("AWS_ACCESS_KEY_ID" ) or provider .get ("access_key" , "" )).strip ()
427
+ access_key = str (
428
+ os .environ .get ("AWS_ACCESS_KEY_ID" ) or provider .get ("access_key" , "" )
429
+ ).strip ()
373
430
if access_key and access_key != DEFAULT_ACCESS_KEY :
374
431
# Change live access key to mocked one
375
432
return deactivate_access_key (access_key )
@@ -378,6 +435,7 @@ def get_access_key(provider: dict) -> str:
378
435
# Note that boto3 is currently not included in the dependencies, to
379
436
# keep the library lightweight.
380
437
import boto3
438
+
381
439
access_key = boto3 .session .Session ().get_credentials ().access_key
382
440
except Exception :
383
441
pass
@@ -387,7 +445,7 @@ def get_access_key(provider: dict) -> str:
387
445
388
446
def deactivate_access_key (access_key : str ) -> str :
389
447
"""Safe guarding user from accidental live credential usage by deactivating access key IDs.
390
- See more: https://docs.localstack.cloud/references/credentials/"""
448
+ See more: https://docs.localstack.cloud/references/credentials/"""
391
449
return "L" + access_key [1 :] if access_key [0 ] == "A" else access_key
392
450
393
451
@@ -413,10 +471,14 @@ def get_service_endpoint(service: str) -> str:
413
471
414
472
def connect_to_service (service : str , region : str = None ):
415
473
import boto3
474
+
416
475
region = region or get_region ()
417
476
return boto3 .client (
418
- service , endpoint_url = get_service_endpoint (service ), region_name = region ,
419
- aws_access_key_id = "test" , aws_secret_access_key = "test" ,
477
+ service ,
478
+ endpoint_url = get_service_endpoint (service ),
479
+ region_name = region ,
480
+ aws_access_key_id = "test" ,
481
+ aws_secret_access_key = "test" ,
420
482
)
421
483
422
484
@@ -440,9 +502,10 @@ def get_or_create_ddb_table(table_name: str, region: str = None):
440
502
return ddb_client .describe_table (TableName = table_name )
441
503
except Exception :
442
504
return ddb_client .create_table (
443
- TableName = table_name , BillingMode = "PAY_PER_REQUEST" ,
505
+ TableName = table_name ,
506
+ BillingMode = "PAY_PER_REQUEST" ,
444
507
KeySchema = [{"AttributeName" : "LockID" , "KeyType" : "HASH" }],
445
- AttributeDefinitions = [{"AttributeName" : "LockID" , "AttributeType" : "S" }]
508
+ AttributeDefinitions = [{"AttributeName" : "LockID" , "AttributeType" : "S" }],
446
509
)
447
510
448
511
@@ -469,13 +532,15 @@ def parse_tf_files() -> dict:
469
532
470
533
def get_tf_version (env ):
471
534
global TF_VERSION
472
- output = subprocess .run ([f"{ TF_CMD } " , "version" , "-json" ], env = env , check = True , capture_output = True ).stdout .decode ("utf-8" )
535
+ output = subprocess .run (
536
+ [f"{ TF_CMD } " , "version" , "-json" ], env = env , check = True , capture_output = True
537
+ ).stdout .decode ("utf-8" )
473
538
TF_VERSION = version .parse (json .loads (output )["terraform_version" ])
474
539
475
540
476
541
def run_tf_exec (cmd , env ):
477
542
"""Run terraform using os.exec - can be useful as it does not require any I/O
478
- handling for stdin/out/err. Does *not* allow us to perform any cleanup logic."""
543
+ handling for stdin/out/err. Does *not* allow us to perform any cleanup logic."""
479
544
os .execvpe (cmd [0 ], cmd , env = env )
480
545
481
546
@@ -485,18 +550,41 @@ def run_tf_subprocess(cmd, env):
485
550
486
551
# register signal handlers
487
552
import signal
553
+
488
554
signal .signal (signal .SIGINT , signal_handler )
489
555
490
556
PROCESS = subprocess .Popen (
491
- cmd , stdin = sys .stdin , stdout = sys .stdout , stderr = sys .stdout )
557
+ cmd , stdin = sys .stdin , stdout = sys .stdout , stderr = sys .stdout
558
+ )
492
559
PROCESS .communicate ()
493
560
sys .exit (PROCESS .returncode )
494
561
495
562
563
+ def cleanup_override_files (override_files : Iterable [str ]):
564
+ for file_path in override_files :
565
+ try :
566
+ os .remove (file_path )
567
+ except Exception :
568
+ print (
569
+ f"Count not clean up '{ file_path } '. This is not normally a problem but you can delete this file manually."
570
+ )
571
+
572
+
573
+ def get_folder_paths_that_require_an_override_file () -> Iterable [str ]:
574
+ if not is_override_needed (sys .argv [1 :]):
575
+ return
576
+
577
+ yield get_default_provider_folder_path ()
578
+ for path in ADDITIONAL_TF_OVERRIDE_LOCATIONS .split (sep = "," ):
579
+ if path .strip ():
580
+ yield path
581
+
582
+
496
583
# ---
497
584
# UTIL FUNCTIONS
498
585
# ---
499
586
587
+
500
588
def signal_handler (sig , frame ):
501
589
PROCESS .send_signal (sig )
502
590
@@ -517,6 +605,7 @@ def to_str(obj) -> bytes:
517
605
# MAIN ENTRYPOINT
518
606
# ---
519
607
608
+
520
609
def main ():
521
610
env = dict (os .environ )
522
611
cmd = [TF_CMD ] + sys .argv [1 :]
@@ -529,26 +618,25 @@ def main():
529
618
print (f"Unable to determine version. See error message for details: { e } " )
530
619
exit (1 )
531
620
532
- if is_override_needed (sys .argv [1 :]):
533
- check_override_file (get_providers_file_path ())
621
+ config_override_files = []
622
+
623
+ for folder_path in get_folder_paths_that_require_an_override_file ():
624
+ config_file_path = get_providers_file_path (folder_path )
625
+ check_override_file (config_file_path )
534
626
535
- # create TF provider config file
536
627
providers = determine_provider_aliases ()
537
- config_file = create_provider_config_file (providers )
538
- else :
539
- config_file = None
628
+ create_provider_config_file (config_file_path , providers )
629
+ config_override_files .append (config_file_path )
540
630
541
631
# call terraform command if not dry-run or any of the commands
542
- if not DRY_RUN or not is_override_needed ( sys . argv [ 1 :]) :
632
+ if not DRY_RUN or not config_override_files :
543
633
try :
544
634
if USE_EXEC :
545
635
run_tf_exec (cmd , env )
546
636
else :
547
637
run_tf_subprocess (cmd , env )
548
638
finally :
549
- # fall through if haven't set during dry-run
550
- if config_file :
551
- os .remove (config_file )
639
+ cleanup_override_files (config_override_files )
552
640
553
641
554
642
if __name__ == "__main__" :
0 commit comments