Skip to content

Commit c3b3302

Browse files
committed
fix(ingest/snowflake): fix type annotations + refactor get_connect_args
In the latest releast of types-PyOpenSSL, they removed their dep on types-cryptography (python/typeshed#9459). This caused mypy to start using the type annotations provided by the cryptography package, which caused an issue in our build. Example build failure: https://github.com/datahub-project/datahub/actions/runs/3888835554/jobs/6636552750
1 parent acd2ba1 commit c3b3302

File tree

1 file changed

+20
-17
lines changed
  • metadata-ingestion/src/datahub/ingestion/source_config/sql

1 file changed

+20
-17
lines changed

metadata-ingestion/src/datahub/ingestion/source_config/sql/snowflake.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Dict, Optional
2+
from typing import Any, Dict, Optional
33

44
import pydantic
55
import snowflake.connector
@@ -153,7 +153,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
153153
default=True,
154154
description="If enabled, populates the snowflake view->table and table->view lineages (no view->view lineage yet). Requires appropriate grants given to the role, and include_table_lineage to be True. view->table lineage requires Snowflake Enterprise Edition or above.",
155155
)
156-
connect_args: Optional[Dict] = pydantic.Field(
156+
connect_args: Optional[Dict[str, Any]] = pydantic.Field(
157157
default=None,
158158
description="Connect args to pass to Snowflake SqlAlchemy driver",
159159
exclude=True,
@@ -297,28 +297,28 @@ def get_sql_alchemy_url(
297297
},
298298
)
299299

300+
_computed_connect_args: Optional[dict] = None
301+
300302
def get_connect_args(self) -> dict:
301303
"""
302-
Builds connect args and updates self.connect_args so that
303-
Subsequent calls to this method are efficient, i.e. do not read files again
304+
Builds connect args, adding defaults and reading a private key from the file if needed.
305+
Caches the results in a private instance variable to avoid reading the file multiple times.
304306
"""
305307

306-
base_connect_args = {
308+
if self._computed_connect_args is not None:
309+
return self._computed_connect_args
310+
311+
connect_args: dict = {
307312
# Improves performance and avoids timeout errors for larger query result
308313
CLIENT_PREFETCH_THREADS: 10,
309314
CLIENT_SESSION_KEEP_ALIVE: True,
310-
}
311-
312-
if self.connect_args is None:
313-
self.connect_args = base_connect_args
314-
else:
315315
# Let user override the default config values
316-
base_connect_args.update(self.connect_args)
317-
self.connect_args = base_connect_args
316+
**(self.connect_args or {}),
317+
}
318318

319319
if (
320-
self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
321-
and "private_key" not in self.connect_args.keys()
320+
"private_key" not in connect_args
321+
and self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
322322
):
323323
if self.private_key is not None:
324324
pkey_bytes = self.private_key.replace("\\n", "\n").encode()
@@ -337,13 +337,16 @@ def get_connect_args(self) -> dict:
337337
backend=default_backend(),
338338
)
339339

340-
pkb = p_key.private_bytes(
340+
pkb: bytes = p_key.private_bytes(
341341
encoding=serialization.Encoding.DER,
342342
format=serialization.PrivateFormat.PKCS8,
343343
encryption_algorithm=serialization.NoEncryption(),
344344
)
345-
self.connect_args.update({"private_key": pkb})
346-
return self.connect_args
345+
346+
connect_args["private_key"] = pkb
347+
348+
self._computed_connect_args = connect_args
349+
return connect_args
347350

348351

349352
class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):

0 commit comments

Comments
 (0)