Commit b0163c4

feat(ingest): utilities for query logs (#10036)
1 parent: 4535f2a

10 files changed: +583 -291 lines changed

10 files changed

+583
-291
lines changed

metadata-ingestion/setup.py

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@
 sqlglot_lib = {
     # Using an Acryl fork of sqlglot.
     # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
-    "acryl-sqlglot==22.3.1.dev3",
+    "acryl-sqlglot==22.4.1.dev4",
 }

 classification_lib = {

metadata-ingestion/src/datahub/cli/check_cli.py

Lines changed: 29 additions & 0 deletions

@@ -1,4 +1,7 @@
+import dataclasses
+import json
 import logging
+import pathlib
 import pprint
 import shutil
 import tempfile
@@ -17,6 +20,7 @@
 from datahub.ingestion.source.source_registry import source_registry
 from datahub.ingestion.transformer.transform_registry import transform_registry
 from datahub.telemetry import telemetry
+from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

 logger = logging.getLogger(__name__)

@@ -339,3 +343,28 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
             f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}"
         )
         raise e
+
+
+@check.command()
+@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
+@click.option("--output", type=click.Path())
+def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
+    """Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
+
+    from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
+
+    assert dataclasses.is_dataclass(LoggedQuery)
+
+    shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
+    query_log = FileBackedList[LoggedQuery](
+        shared_connection=shared_connection, tablename="stored_queries"
+    )
+    logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
+    queries = [dataclasses.asdict(query) for query in query_log]
+
+    if output:
+        with open(output, "w") as f:
+            json.dump(queries, f, indent=2)
+        logger.info(f"Extracted {len(queries)} queries to {output}")
+    else:
+        click.echo(json.dumps(queries, indent=2))
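For reference, a minimal sketch of exercising the new subcommand through click's test runner. It assumes click's default underscore-to-dash command naming (so the subcommand would be invoked as "datahub check extract-sql-agg-log"), and query_log.db is a hypothetical sqlite file produced by a SqlParsingAggregator run with query logging enabled.

from click.testing import CliRunner

from datahub.cli.check_cli import check

# "query_log.db" is a hypothetical path; click.Path(exists=True) means the file
# must actually exist for the invocation to succeed.
runner = CliRunner()
result = runner.invoke(
    check,
    ["extract-sql-agg-log", "query_log.db", "--output", "queries.json"],
)
print(result.exit_code)  # 0 on success; queries.json then holds the extracted entries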

metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py

Lines changed: 62 additions & 13 deletions

@@ -1,8 +1,10 @@
+import contextlib
 import dataclasses
 import enum
 import itertools
 import json
 import logging
+import os
 import pathlib
 import tempfile
 import uuid
@@ -15,6 +17,7 @@
 from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.sql_parsing_builder import compute_upstream_fields
+from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.report import Report
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.graph.client import DataHubGraph
@@ -53,16 +56,30 @@
 QueryId = str
 UrnStr = str

-_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
-_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
-

 class QueryLogSetting(enum.Enum):
     DISABLED = "DISABLED"
     STORE_ALL = "STORE_ALL"
     STORE_FAILED = "STORE_FAILED"


+_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
+_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
+_DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
+    os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
+]
+
+
+@dataclasses.dataclass
+class LoggedQuery:
+    query: str
+    session_id: Optional[str]
+    timestamp: Optional[datetime]
+    user: Optional[UrnStr]
+    default_db: Optional[str]
+    default_schema: Optional[str]
+
+
 @dataclasses.dataclass
 class ViewDefinition:
     view_definition: str
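A quick sketch of how the new module-level default resolves, using only the expression shown above: the enum member is looked up by name, with DISABLED as the fallback when DATAHUB_SQL_AGG_QUERY_LOG is unset or empty. An unrecognized value raises a KeyError, since Enum[...] does an exact name lookup, and in the module itself this happens once at import time.

import os

from datahub.sql_parsing.sql_parsing_aggregator import QueryLogSetting

# Mirrors the module-level lookup above; evaluated here at call time for illustration.
setting = QueryLogSetting[
    os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
]
print(setting)  # QueryLogSetting.DISABLED unless the env var names another member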
@@ -170,7 +187,7 @@ def compute_stats(self) -> None:
         return super().compute_stats()


-class SqlParsingAggregator:
+class SqlParsingAggregator(Closeable):
     def __init__(
         self,
         *,
@@ -185,7 +202,7 @@ def __init__(
         usage_config: Optional[BaseUsageConfig] = None,
         is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
         format_queries: bool = True,
-        query_log: QueryLogSetting = QueryLogSetting.DISABLED,
+        query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
     ) -> None:
         self.platform = DataPlatformUrn(platform)
         self.platform_instance = platform_instance
@@ -210,13 +227,18 @@ def __init__(
         self.format_queries = format_queries
         self.query_log = query_log

+        # The exit stack helps ensure that we close all the resources we open.
+        self._exit_stack = contextlib.ExitStack()
+
         # Set up the schema resolver.
         self._schema_resolver: SchemaResolver
         if graph is None:
-            self._schema_resolver = SchemaResolver(
-                platform=self.platform.platform_name,
-                platform_instance=self.platform_instance,
-                env=self.env,
+            self._schema_resolver = self._exit_stack.enter_context(
+                SchemaResolver(
+                    platform=self.platform.platform_name,
+                    platform_instance=self.platform_instance,
+                    env=self.env,
+                )
             )
         else:
             self._schema_resolver = None  # type: ignore
@@ -235,44 +257,54 @@ def __init__(

         # By providing a filename explicitly here, we also ensure that the file
         # is not automatically deleted on exit.
-        self._shared_connection = ConnectionWrapper(filename=query_log_path)
+        self._shared_connection = self._exit_stack.enter_context(
+            ConnectionWrapper(filename=query_log_path)
+        )

         # Stores the logged queries.
-        self._logged_queries = FileBackedList[str](
+        self._logged_queries = FileBackedList[LoggedQuery](
             shared_connection=self._shared_connection, tablename="stored_queries"
         )
+        self._exit_stack.push(self._logged_queries)

         # Map of query_id -> QueryMetadata
         self._query_map = FileBackedDict[QueryMetadata](
             shared_connection=self._shared_connection, tablename="query_map"
         )
+        self._exit_stack.push(self._query_map)

         # Map of downstream urn -> { query ids }
         self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
             shared_connection=self._shared_connection, tablename="lineage_map"
         )
+        self._exit_stack.push(self._lineage_map)

         # Map of view urn -> view definition
         self._view_definitions = FileBackedDict[ViewDefinition](
             shared_connection=self._shared_connection, tablename="view_definitions"
         )
+        self._exit_stack.push(self._view_definitions)

         # Map of session ID -> {temp table name -> query id}
         # Needs to use the query_map to find the info about the query.
         # This assumes that a temp table is created at most once per session.
         self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
             shared_connection=self._shared_connection, tablename="temp_lineage_map"
         )
+        self._exit_stack.push(self._temp_lineage_map)

         # Map of query ID -> schema fields, only for query IDs that generate temp tables.
         self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
-            shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
+            shared_connection=self._shared_connection,
+            tablename="inferred_temp_schemas",
         )
+        self._exit_stack.push(self._inferred_temp_schemas)

         # Map of table renames, from original UrnStr to new UrnStr.
         self._table_renames = FileBackedDict[UrnStr](
             shared_connection=self._shared_connection, tablename="table_renames"
         )
+        self._exit_stack.push(self._table_renames)

         # Usage aggregator. This will only be initialized if usage statistics are enabled.
         # TODO: Replace with FileBackedDict.
@@ -281,6 +313,9 @@ def __init__(
             assert self.usage_config is not None
             self._usage_aggregator = UsageAggregator(config=self.usage_config)

+    def close(self) -> None:
+        self._exit_stack.close()
+
     @property
     def _need_schemas(self) -> bool:
         return self.generate_lineage or self.generate_usage_statistics
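To illustrate the cleanup pattern used here, the following is a minimal, self-contained sketch (the Resource and Owner names are hypothetical, not part of DataHub): enter_context runs a context manager's __enter__ immediately and records its __exit__, push only records an existing object's __exit__, and close() unwinds everything in reverse registration order, which is what SqlParsingAggregator.close() now delegates to.

import contextlib


class Resource:
    """Hypothetical stand-in for a Closeable such as FileBackedDict."""

    def __init__(self, name: str) -> None:
        self.name = name

    def close(self) -> None:
        print(f"closed {self.name}")

    def __enter__(self) -> "Resource":
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()


class Owner:
    def __init__(self) -> None:
        self._exit_stack = contextlib.ExitStack()
        # __enter__ runs now; __exit__ is registered for later.
        self.a = self._exit_stack.enter_context(Resource("a"))
        # Only __exit__ is registered; nothing runs now.
        self.b = Resource("b")
        self._exit_stack.push(self.b)

    def close(self) -> None:
        # Unwinds in reverse order: closes b, then a.
        self._exit_stack.close()


Owner().close()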
@@ -499,6 +534,9 @@ def add_observed_query(
             default_db=default_db,
             default_schema=default_schema,
             schema_resolver=schema_resolver,
+            session_id=session_id,
+            timestamp=query_timestamp,
+            user=user,
         )
         if parsed.debug_info.error:
             self.report.observed_query_parse_failures.append(
@@ -700,6 +738,9 @@ def _run_sql_parser(
         default_db: Optional[str],
         default_schema: Optional[str],
         schema_resolver: SchemaResolverInterface,
+        session_id: str = _MISSING_SESSION_ID,
+        timestamp: Optional[datetime] = None,
+        user: Optional[CorpUserUrn] = None,
     ) -> SqlParsingResult:
         parsed = sqlglot_lineage(
             query,
@@ -712,7 +753,15 @@
         if self.query_log == QueryLogSetting.STORE_ALL or (
             self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
         ):
-            self._logged_queries.append(query)
+            query_log_entry = LoggedQuery(
+                query=query,
+                session_id=session_id if session_id != _MISSING_SESSION_ID else None,
+                timestamp=timestamp,
+                user=user.urn() if user else None,
+                default_db=default_db,
+                default_schema=default_schema,
+            )
+            self._logged_queries.append(query_log_entry)

         # Also add some extra logging.
         if parsed.debug_info.error:
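To show what ends up in the stored_queries table after this change, here is a small hand-built example of a LoggedQuery entry serialized the way the new check command does it. The query, user, and timestamp values are made up; default=str is used because dataclasses.asdict keeps the datetime object.

import dataclasses
import json
from datetime import datetime, timezone

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery

# Hypothetical values; _run_sql_parser stores None for the missing-session sentinel
# and user.urn() (a string) rather than the CorpUserUrn object.
entry = LoggedQuery(
    query="create table #temp_orders as select * from orders",
    session_id="session-1234",
    timestamp=datetime(2024, 3, 1, tzinfo=timezone.utc),
    user="urn:li:corpuser:jdoe",
    default_db="analytics",
    default_schema="public",
)
print(json.dumps(dataclasses.asdict(entry), indent=2, default=str))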

metadata-ingestion/src/datahub/testing/compare_metadata_json.py

Lines changed: 8 additions & 4 deletions

@@ -62,9 +62,13 @@ def assert_metadata_files_equal(
     # We have to "normalize" the golden file by reading and writing it back out.
     # This will clean up nulls, double serialization, and other formatting issues.
     with tempfile.NamedTemporaryFile() as temp:
-        golden_metadata = read_metadata_file(pathlib.Path(golden_path))
-        write_metadata_file(pathlib.Path(temp.name), golden_metadata)
-        golden = load_json_file(temp.name)
+        try:
+            golden_metadata = read_metadata_file(pathlib.Path(golden_path))
+            write_metadata_file(pathlib.Path(temp.name), golden_metadata)
+            golden = load_json_file(temp.name)
+        except (ValueError, AssertionError) as e:
+            logger.info(f"Error reformatting golden file as MCP/MCEs: {e}")
+            golden = load_json_file(golden_path)

     diff = diff_metadata_json(output, golden, ignore_paths, ignore_order=ignore_order)
     if diff and update_golden:
@@ -107,7 +111,7 @@ def diff_metadata_json(
     # if ignore_order is False, always use DeepDiff
     except CannotCompareMCPs as e:
         logger.info(f"{e}, falling back to MCE diff")
-    except AssertionError as e:
+    except (AssertionError, ValueError) as e:
         logger.warning(f"Reverting to old diff method: {e}")
         logger.debug("Error with new diff method", exc_info=True)

metadata-ingestion/src/datahub/utilities/file_backed_collections.py

Lines changed: 8 additions & 2 deletions

@@ -126,6 +126,7 @@ def executemany(
     def close(self) -> None:
         for obj in self._dependent_objects:
             obj.close()
+        self._dependent_objects.clear()
         with self.conn_lock:
             self.conn.close()
         if self._temp_directory:
@@ -440,7 +441,7 @@ def __del__(self) -> None:
         self.close()


-class FileBackedList(Generic[_VT]):
+class FileBackedList(Generic[_VT], Closeable):
     """An append-only, list-like object that stores its contents in a SQLite database."""

     _len: int = field(default=0)
@@ -456,7 +457,6 @@ def __init__(
         cache_max_size: Optional[int] = None,
         cache_eviction_batch_size: Optional[int] = None,
     ) -> None:
-        self._len = 0
         self._dict = FileBackedDict[_VT](
             shared_connection=shared_connection,
             tablename=tablename,
@@ -468,6 +468,12 @@ def __init__(
             or _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE,
         )

+        if shared_connection:
+            shared_connection._dependent_objects.append(self)
+
+        # In case we're reusing an existing list, we need to run a query to get the length.
+        self._len = len(self._dict)
+
     @property
     def tablename(self) -> str:
         return self._dict.tablename
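These changes are what let a later process reopen a query log that an aggregator left behind: registering the list with the connection's dependent objects means closing the connection also closes the list, and recomputing _len from the table makes len() correct for a pre-existing table. A minimal sketch, assuming query_log.db is such a leftover file:

import pathlib

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

# Hypothetical path to a sqlite query log written by a previous aggregator run.
conn = ConnectionWrapper(pathlib.Path("query_log.db"))
queries = FileBackedList[LoggedQuery](
    shared_connection=conn, tablename="stored_queries"
)
print(len(queries))  # length comes from the existing table, not a fresh counter
conn.close()  # closes registered dependents (including queries) before the sqlite handle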
