@@ -1,8 +1,10 @@
+ import contextlib
import dataclasses
import enum
import itertools
import json
import logging
+ import os
import pathlib
import tempfile
import uuid
@@ -15,6 +17,7 @@
from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.sql_parsing_builder import compute_upstream_fields
+ from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
@@ -53,16 +56,30 @@
QueryId = str
UrnStr = str

- _DEFAULT_USER_URN = CorpUserUrn("_ingestion")
- _MISSING_SESSION_ID = "__MISSING_SESSION_ID"
-

class QueryLogSetting(enum.Enum):
    DISABLED = "DISABLED"
    STORE_ALL = "STORE_ALL"
    STORE_FAILED = "STORE_FAILED"


+ _DEFAULT_USER_URN = CorpUserUrn("_ingestion")
+ _MISSING_SESSION_ID = "__MISSING_SESSION_ID"
+ _DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
+     os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
+ ]
+
+
+ @dataclasses.dataclass
+ class LoggedQuery:
+     query: str
+     session_id: Optional[str]
+     timestamp: Optional[datetime]
+     user: Optional[UrnStr]
+     default_db: Optional[str]
+     default_schema: Optional[str]
+
+
@dataclasses.dataclass
class ViewDefinition:
    view_definition: str
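
The new _DEFAULT_QUERY_LOG_SETTING is resolved once at import time: QueryLogSetting[...] indexes the enum by member name, falling back to DISABLED when the environment variable is unset or empty. A minimal standalone sketch of that lookup behavior, mirroring the enum above:

import enum
import os

class QueryLogSetting(enum.Enum):
    DISABLED = "DISABLED"
    STORE_ALL = "STORE_ALL"
    STORE_FAILED = "STORE_FAILED"

# Name-based enum indexing: an unset or empty env var falls back to DISABLED,
# while an unrecognized value raises KeyError at import time.
setting = QueryLogSetting[
    os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
]
assert setting is QueryLogSetting.DISABLED  # holds when the env var is not set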
@@ -170,7 +187,7 @@ def compute_stats(self) -> None:
        return super().compute_stats()


- class SqlParsingAggregator:
+ class SqlParsingAggregator(Closeable):
    def __init__(
        self,
        *,
@@ -185,7 +202,7 @@ def __init__(
        usage_config: Optional[BaseUsageConfig] = None,
        is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
        format_queries: bool = True,
-         query_log: QueryLogSetting = QueryLogSetting.DISABLED,
+         query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
    ) -> None:
        self.platform = DataPlatformUrn(platform)
        self.platform_instance = platform_instance
@@ -210,13 +227,18 @@ def __init__(
        self.format_queries = format_queries
        self.query_log = query_log

+         # The exit stack helps ensure that we close all the resources we open.
+         self._exit_stack = contextlib.ExitStack()
+
        # Set up the schema resolver.
        self._schema_resolver: SchemaResolver
        if graph is None:
-             self._schema_resolver = SchemaResolver(
-                 platform=self.platform.platform_name,
-                 platform_instance=self.platform_instance,
-                 env=self.env,
+             self._schema_resolver = self._exit_stack.enter_context(
+                 SchemaResolver(
+                     platform=self.platform.platform_name,
+                     platform_instance=self.platform_instance,
+                     env=self.env,
+                 )
            )
        else:
            self._schema_resolver = None  # type: ignore
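
The ExitStack introduced here collects every resource as it is created so that a single close() tears them all down in reverse order. A minimal standalone sketch of the two registration styles used in this change (the Resource class is hypothetical):

import contextlib

class Resource:
    # Hypothetical stand-in for SchemaResolver / ConnectionWrapper / FileBackedDict.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print("released")

stack = contextlib.ExitStack()
a = stack.enter_context(Resource())  # calls __enter__ now, __exit__ on close
b = Resource()
stack.push(b)   # registers only the object's __exit__; __enter__ is never called
stack.close()   # releases b, then a (reverse order of registration)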
@@ -235,44 +257,54 @@ def __init__(

        # By providing a filename explicitly here, we also ensure that the file
        # is not automatically deleted on exit.
-         self._shared_connection = ConnectionWrapper(filename=query_log_path)
+         self._shared_connection = self._exit_stack.enter_context(
+             ConnectionWrapper(filename=query_log_path)
+         )

        # Stores the logged queries.
-         self._logged_queries = FileBackedList[str](
+         self._logged_queries = FileBackedList[LoggedQuery](
            shared_connection=self._shared_connection, tablename="stored_queries"
        )
+         self._exit_stack.push(self._logged_queries)

        # Map of query_id -> QueryMetadata
        self._query_map = FileBackedDict[QueryMetadata](
            shared_connection=self._shared_connection, tablename="query_map"
        )
+         self._exit_stack.push(self._query_map)

        # Map of downstream urn -> { query ids }
        self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
            shared_connection=self._shared_connection, tablename="lineage_map"
        )
+         self._exit_stack.push(self._lineage_map)

        # Map of view urn -> view definition
        self._view_definitions = FileBackedDict[ViewDefinition](
            shared_connection=self._shared_connection, tablename="view_definitions"
        )
+         self._exit_stack.push(self._view_definitions)

        # Map of session ID -> {temp table name -> query id}
        # Needs to use the query_map to find the info about the query.
        # This assumes that a temp table is created at most once per session.
        self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
            shared_connection=self._shared_connection, tablename="temp_lineage_map"
        )
+         self._exit_stack.push(self._temp_lineage_map)

        # Map of query ID -> schema fields, only for query IDs that generate temp tables.
        self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
-             shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
+             shared_connection=self._shared_connection,
+             tablename="inferred_temp_schemas",
        )
+         self._exit_stack.push(self._inferred_temp_schemas)

        # Map of table renames, from original UrnStr to new UrnStr.
        self._table_renames = FileBackedDict[UrnStr](
            shared_connection=self._shared_connection, tablename="table_renames"
        )
+         self._exit_stack.push(self._table_renames)

        # Usage aggregator. This will only be initialized if usage statistics are enabled.
        # TODO: Replace with FileBackedDict.
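
Each of these collections lives in its own table of a single SQLite file, sharing one connection. A hedged sketch of the pattern, assuming the datahub.utilities.file_backed_collections module path and a close() method on ConnectionWrapper (verify against your version of the codebase):

import pathlib
import tempfile

from datahub.utilities.file_backed_collections import (
    ConnectionWrapper,
    FileBackedDict,
)

# One SQLite file backs several logical collections, separated by tablename.
db_path = pathlib.Path(tempfile.mkdtemp()) / "aggregator_state.db"
shared = ConnectionWrapper(filename=db_path)
renames = FileBackedDict[str](shared_connection=shared, tablename="table_renames")
renames["urn:li:dataset:old"] = "urn:li:dataset:new"  # illustrative keys/values
shared.close()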
@@ -281,6 +313,9 @@ def __init__(
            assert self.usage_config is not None
            self._usage_aggregator = UsageAggregator(config=self.usage_config)

+     def close(self) -> None:
+         self._exit_stack.close()
+
    @property
    def _need_schemas(self) -> bool:
        return self.generate_lineage or self.generate_usage_statistics
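
Because the aggregator now implements Closeable, callers can scope it with contextlib.closing so that close() (and therefore the whole exit stack) runs even if an error is raised. A hedged usage sketch; the constructor and add_observed_query arguments shown here are abbreviated assumptions, and real usage will likely pass more configuration:

import contextlib

with contextlib.closing(SqlParsingAggregator(platform="redshift")) as aggregator:
    # Feed observed queries; all file-backed state is released on exit.
    aggregator.add_observed_query(
        query="create temp table foo as select * from bar",
        default_db="dev",
        default_schema="public",
    )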
@@ -499,6 +534,9 @@ def add_observed_query(
            default_db=default_db,
            default_schema=default_schema,
            schema_resolver=schema_resolver,
+             session_id=session_id,
+             timestamp=query_timestamp,
+             user=user,
        )
        if parsed.debug_info.error:
            self.report.observed_query_parse_failures.append(
@@ -700,6 +738,9 @@ def _run_sql_parser(
        default_db: Optional[str],
        default_schema: Optional[str],
        schema_resolver: SchemaResolverInterface,
+         session_id: str = _MISSING_SESSION_ID,
+         timestamp: Optional[datetime] = None,
+         user: Optional[CorpUserUrn] = None,
    ) -> SqlParsingResult:
        parsed = sqlglot_lineage(
            query,
@@ -712,7 +753,15 @@ def _run_sql_parser(
        if self.query_log == QueryLogSetting.STORE_ALL or (
            self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
        ):
-             self._logged_queries.append(query)
+             query_log_entry = LoggedQuery(
+                 query=query,
+                 session_id=session_id if session_id != _MISSING_SESSION_ID else None,
+                 timestamp=timestamp,
+                 user=user.urn() if user else None,
+                 default_db=default_db,
+                 default_schema=default_schema,
+             )
+             self._logged_queries.append(query_log_entry)

        # Also add some extra logging.
        if parsed.debug_info.error:
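
With this change the query log stores structured LoggedQuery records rather than bare SQL strings, so each entry carries its session, timestamp, user, and default db/schema context. An illustrative entry (all values invented):

from datetime import datetime, timezone

entry = LoggedQuery(
    query="create table #temp_orders as select * from orders",
    session_id="session-123",  # None when the session id was _MISSING_SESSION_ID
    timestamp=datetime(2024, 1, 15, tzinfo=timezone.utc),
    user="urn:li:corpuser:jdoe",  # user.urn() result, or None
    default_db="dev",
    default_schema="public",
)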