Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ee6c593

Browse files
authored
Merge pull request #92 from datafold/fix_june22
Fix: Only parse relevant columns. Only warn on relevant columns.
2 parents 474de40 + 75e3e52 commit ee6c593

File tree

3 files changed

+60
-32
lines changed

3 files changed

+60
-32
lines changed

data_diff/database.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from runtype import dataclass
77
import logging
8-
from typing import Tuple, Optional, List
8+
from typing import Sequence, Tuple, Optional, List
99
from concurrent.futures import ThreadPoolExecutor
1010
import threading
1111
from typing import Dict
@@ -159,10 +159,6 @@ def __post_init__(self):
159159
class UnknownColType(ColType):
160160
text: str
161161

162-
def __post_init__(self):
163-
logger.warn(f"Column of type '{self.text}' has no compatibility handling. "
164-
"If encoding/formatting differs between databases, it may result in false positives.")
165-
166162

167163
class AbstractDatabase(ABC):
168164
@abstractmethod
@@ -191,7 +187,7 @@ def select_table_schema(self, path: DbPath) -> str:
191187
...
192188

193189
@abstractmethod
194-
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
190+
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
195191
"Query the table for its schema for table in 'path', and return {column: type}"
196192
...
197193

@@ -205,7 +201,6 @@ def close(self):
205201
"Close connection(s) to the database instance. Querying will stop functioning."
206202
...
207203

208-
209204
@abstractmethod
210205
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
211206
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -269,6 +264,10 @@ class Database(AbstractDatabase):
269264
DATETIME_TYPES = {}
270265
default_schema = None
271266

267+
@property
268+
def name(self):
269+
return type(self).__name__
270+
272271
def query(self, sql_ast: SqlOrStr, res_type: type):
273272
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
274273

@@ -310,7 +309,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
310309
return math.floor(math.log(2**p, 10))
311310

312311
def _parse_type(
313-
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
312+
self,
313+
col_name: str,
314+
type_repr: str,
315+
datetime_precision: int = None,
316+
numeric_precision: int = None,
317+
numeric_scale: int = None,
314318
) -> ColType:
315319
""" """
316320

@@ -329,6 +333,8 @@ def _parse_type(
329333
return cls(precision=0)
330334

331335
elif issubclass(cls, Decimal):
336+
if numeric_scale is None:
337+
raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.")
332338
return cls(precision=numeric_scale)
333339

334340
assert issubclass(cls, Float)
@@ -349,13 +355,17 @@ def select_table_schema(self, path: DbPath) -> str:
349355
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
350356
)
351357

352-
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
358+
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
353359
rows = self.query(self.select_table_schema(path), list)
354360
if not rows:
355-
raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns")
361+
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
362+
363+
if filter_columns is not None:
364+
accept = {i.lower() for i in filter_columns}
365+
rows = [r for r in rows if r[0].lower() in accept]
356366

357-
# Return a dict of form {name: type} after canonizaation
358-
return {row[0]: self._parse_type(*row[1:]) for row in rows}
367+
# Return a dict of form {name: type} after normalization
368+
return {row[0]: self._parse_type(*row) for row in rows}
359369

360370
# @lru_cache()
361371
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -366,9 +376,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
366376
if self.default_schema:
367377
return self.default_schema, path[0]
368378
elif len(path) != 2:
369-
raise ValueError(
370-
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
371-
)
379+
raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
372380

373381
return path
374382

@@ -440,6 +448,7 @@ class PostgreSQL(ThreadedDatabase):
440448
"decimal": Decimal,
441449
"integer": Integer,
442450
"numeric": Decimal,
451+
"bigint": Integer,
443452
}
444453
ROUNDS_ON_PREC_LOSS = True
445454

@@ -472,13 +481,14 @@ def md5_to_int(self, s: str) -> str:
472481
def to_string(self, s: str):
473482
return f"{s}::varchar"
474483

475-
476484
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
477485
if coltype.rounds:
478486
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
479487

480488
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
481-
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
489+
return (
490+
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
491+
)
482492

483493
def normalize_number(self, value: str, coltype: ColType) -> str:
484494
return self.to_string(f"{value}::decimal(38, {coltype.precision})")
@@ -528,9 +538,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
528538
else:
529539
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
530540

531-
return (
532-
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
533-
)
541+
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
534542

535543
def normalize_number(self, value: str, coltype: ColType) -> str:
536544
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
@@ -543,7 +551,9 @@ def select_table_schema(self, path: DbPath) -> str:
543551
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
544552
)
545553

546-
def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
554+
def _parse_type(
555+
self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None
556+
) -> ColType:
547557
regexps = {
548558
r"timestamp\((\d)\)": Timestamp,
549559
r"timestamp\((\d)\) with time zone": TimestampTZ,
@@ -633,7 +643,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
633643
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
634644

635645

636-
637646
class Oracle(ThreadedDatabase):
638647
ROUNDS_ON_PREC_LOSS = True
639648

@@ -687,7 +696,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
687696
return f"to_char({value}, '{format_str}')"
688697

689698
def _parse_type(
690-
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
699+
self,
700+
col_name: str,
701+
type_repr: str,
702+
datetime_precision: int = None,
703+
numeric_precision: int = None,
704+
numeric_scale: int = None,
691705
) -> ColType:
692706
""" """
693707
regexps = {
@@ -746,15 +760,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
746760
us = f"extract(us from {timestamp})"
747761
# epoch = Total time since epoch in microseconds.
748762
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
749-
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
763+
timestamp6 = (
764+
f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
765+
)
750766
else:
751767
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
752-
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
768+
return (
769+
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
770+
)
753771

754772
def normalize_number(self, value: str, coltype: ColType) -> str:
755773
return self.to_string(f"{value}::decimal(38,{coltype.precision})")
756774

757-
758775
def select_table_schema(self, path: DbPath) -> str:
759776
schema, table = self._normalize_table_path(path)
760777

@@ -864,7 +881,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
864881
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
865882

866883
timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
867-
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
884+
return (
885+
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
886+
)
868887

869888
def normalize_number(self, value: str, coltype: ColType) -> str:
870889
if isinstance(coltype, Integer):

data_diff/diff_tables.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from runtype import dataclass
1313

1414
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max
15-
from .database import Database, NumericType, PrecisionType, ColType
15+
from .database import Database, NumericType, PrecisionType, ColType, UnknownColType
1616

1717
logger = logging.getLogger("diff_tables")
1818

@@ -142,7 +142,8 @@ def with_schema(self) -> "TableSegment":
142142
"Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
143143
if self._schema:
144144
return self
145-
schema = self.database.query_table_schema(self.table_path)
145+
146+
schema = self.database.query_table_schema(self.table_path, self._relevant_columns)
146147
if self.case_sensitive:
147148
schema = Schema_CaseSensitive(schema)
148149
else:
@@ -381,6 +382,15 @@ def _validate_and_adjust_columns(self, table1, table2):
381382
table1._schema[c] = col1.replace(precision=lowest.precision)
382383
table2._schema[c] = col2.replace(precision=lowest.precision)
383384

385+
for t in [table1, table2]:
386+
for c in t._relevant_columns:
387+
ctype = t._schema[c]
388+
if isinstance(ctype, UnknownColType):
389+
logger.warn(
390+
f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. "
391+
"If encoding/formatting differs between databases, it may result in false positives."
392+
)
393+
384394
def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
385395
assert table1.is_bounded and table2.is_bounded
386396

tests/test_diff_tables.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ def tearDownClass(cls):
3232
cls.preql.close()
3333
cls.connection.close()
3434

35-
3635
# Fallback for test runners that doesn't support setUpClass/tearDownClass
3736
def setUp(self) -> None:
38-
if not hasattr(self, 'connection'):
37+
if not hasattr(self, "connection"):
3938
self.setUpClass.__func__(self)
4039
self.private_connection = True
4140

4241
return super().setUp()
4342

4443
def tearDown(self) -> None:
45-
if hasattr(self, 'private_connection'):
44+
if hasattr(self, "private_connection"):
4645
self.tearDownClass.__func__(self)
4746

4847
return super().tearDown()

0 commit comments

Comments
 (0)