Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Fix: Only parse relevant columns. Only warn on relevant columns. #92

Merged
merged 3 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions data_diff/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from runtype import dataclass
import logging
from typing import Tuple, Optional, List
from typing import Sequence, Tuple, Optional, List
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import Dict
Expand Down Expand Up @@ -131,10 +131,6 @@ def __post_init__(self):
class UnknownColType(ColType):
text: str

def __post_init__(self):
logger.warn(f"Column of type '{self.text}' has no compatibility handling. "
"If encoding/formatting differs between databases, it may result in false positives.")


class AbstractDatabase(ABC):
@abstractmethod
Expand Down Expand Up @@ -163,7 +159,7 @@ def select_table_schema(self, path: DbPath) -> str:
...

@abstractmethod
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
"Query the table for its schema for table in 'path', and return {column: type}"
...

Expand All @@ -177,7 +173,6 @@ def close(self):
"Close connection(s) to the database instance. Querying will stop functioning."
...


@abstractmethod
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
Expand Down Expand Up @@ -241,6 +236,10 @@ class Database(AbstractDatabase):
DATETIME_TYPES = {}
default_schema = None

@property
def name(self):
return type(self).__name__

def query(self, sql_ast: SqlOrStr, res_type: type):
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"

Expand Down Expand Up @@ -282,7 +281,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
return math.floor(math.log(2**p, 10))

def _parse_type(
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
self,
col_name: str,
type_repr: str,
datetime_precision: int = None,
numeric_precision: int = None,
numeric_scale: int = None,
) -> ColType:
""" """

Expand All @@ -301,6 +305,8 @@ def _parse_type(
return cls(precision=0)

elif issubclass(cls, Decimal):
if numeric_scale is None:
raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.")
return cls(precision=numeric_scale)

assert issubclass(cls, Float)
Expand All @@ -321,13 +327,17 @@ def select_table_schema(self, path: DbPath) -> str:
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
rows = self.query(self.select_table_schema(path), list)
if not rows:
raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns")
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

if filter_columns is not None:
accept = {i.lower() for i in filter_columns}
rows = [r for r in rows if r[0].lower() in accept]

# Return a dict of form {name: type} after canonizaation
return {row[0]: self._parse_type(*row[1:]) for row in rows}
# Return a dict of form {name: type} after normalization
return {row[0]: self._parse_type(*row) for row in rows}

# @lru_cache()
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
Expand All @@ -338,13 +348,10 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
if self.default_schema:
return self.default_schema, path[0]
elif len(path) != 2:
raise ValueError(
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
)
raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")

return path


def parse_table_name(self, name: str) -> DbPath:
return parse_table_name(name)

Expand Down Expand Up @@ -407,6 +414,7 @@ class Postgres(ThreadedDatabase):
"decimal": Decimal,
"integer": Integer,
"numeric": Decimal,
"bigint": Integer,
}
ROUNDS_ON_PREC_LOSS = True

Expand Down Expand Up @@ -439,13 +447,14 @@ def md5_to_int(self, s: str) -> str:
def to_string(self, s: str):
return f"{s}::varchar"


def normalize_timestamp(self, value: str, coltype: ColType) -> str:
if coltype.rounds:
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"

timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
return (
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
return self.to_string(f"{value}::decimal(38, {coltype.precision})")
Expand Down Expand Up @@ -495,9 +504,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
else:
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

return (
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
)
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

def normalize_number(self, value: str, coltype: ColType) -> str:
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
Expand All @@ -510,7 +517,9 @@ def select_table_schema(self, path: DbPath) -> str:
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
def _parse_type(
self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None
) -> ColType:
regexps = {
r"timestamp\((\d)\)": Timestamp,
r"timestamp\((\d)\) with time zone": TimestampTZ,
Expand Down Expand Up @@ -600,7 +609,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")



class Oracle(ThreadedDatabase):
ROUNDS_ON_PREC_LOSS = True

Expand Down Expand Up @@ -654,7 +662,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
return f"to_char({value}, '{format_str}')"

def _parse_type(
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
self,
col_name: str,
type_repr: str,
datetime_precision: int = None,
numeric_precision: int = None,
numeric_scale: int = None,
) -> ColType:
""" """
regexps = {
Expand Down Expand Up @@ -713,15 +726,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
us = f"extract(us from {timestamp})"
# epoch = Total time since epoch in microseconds.
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
timestamp6 = (
f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
)
else:
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
return (
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
return self.to_string(f"{value}::decimal(38,{coltype.precision})")


def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

Expand Down Expand Up @@ -831,7 +847,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"

timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
return (
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, Integer):
Expand Down
14 changes: 12 additions & 2 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from runtype import dataclass

from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max
from .database import Database, NumericType, PrecisionType, ColType
from .database import Database, NumericType, PrecisionType, ColType, UnknownColType

logger = logging.getLogger("diff_tables")

Expand Down Expand Up @@ -142,7 +142,8 @@ def with_schema(self) -> "TableSegment":
"Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
if self._schema:
return self
schema = self.database.query_table_schema(self.table_path)

schema = self.database.query_table_schema(self.table_path, self._relevant_columns)
if self.case_sensitive:
schema = Schema_CaseSensitive(schema)
else:
Expand Down Expand Up @@ -381,6 +382,15 @@ def _validate_and_adjust_columns(self, table1, table2):
table1._schema[c] = col1.replace(precision=lowest.precision)
table2._schema[c] = col2.replace(precision=lowest.precision)

for t in [table1, table2]:
for c in t._relevant_columns:
ctype = t._schema[c]
if isinstance(ctype, UnknownColType):
logger.warn(
f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. "
"If encoding/formatting differs between databases, it may result in false positives."
)

def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
assert table1.is_bounded and table2.is_bounded

Expand Down
5 changes: 2 additions & 3 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ def tearDownClass(cls):
cls.preql.close()
cls.connection.close()


# Fallback for test runners that doesn't support setUpClass/tearDownClass
def setUp(self) -> None:
if not hasattr(self, 'connection'):
if not hasattr(self, "connection"):
self.setUpClass.__func__(self)
self.private_connection = True

return super().setUp()

def tearDown(self) -> None:
if hasattr(self, 'private_connection'):
if hasattr(self, "private_connection"):
self.tearDownClass.__func__(self)

return super().tearDown()
Expand Down