From 101707bb58c8c2514db6a1a8b115677470411fd4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 11:07:25 +0200 Subject: [PATCH 1/3] Fix: Only parse relevant columns. Only warn on relevant columns. --- data_diff/database.py | 25 +++++++++++++++---------- data_diff/diff_tables.py | 12 ++++++++++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index c4d5edf4..79117f0e 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from runtype import dataclass import logging -from typing import Tuple, Optional, List +from typing import Sequence, Tuple, Optional, List from concurrent.futures import ThreadPoolExecutor import threading from typing import Dict @@ -131,10 +131,6 @@ def __post_init__(self): class UnknownColType(ColType): text: str - def __post_init__(self): - logger.warn(f"Column of type '{self.text}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives.") - class AbstractDatabase(ABC): @abstractmethod @@ -163,7 +159,7 @@ def select_table_schema(self, path: DbPath) -> str: ... @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: "Query the table for its schema for table in 'path', and return {column: type}" ... @@ -241,6 +237,10 @@ class Database(AbstractDatabase): DATETIME_TYPES = {} default_schema = None + @property + def name(self): + return type(self).__name__ + def query(self, sql_ast: SqlOrStr, res_type: type): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" @@ -321,12 +321,16 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: rows = self.query(self.select_table_schema(path), list) if not rows: - raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns") + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") + + if filter_columns is not None: + accept = {i.lower() for i in filter_columns} + rows = [r for r in rows if r[0].lower() in accept] - # Return a dict of form {name: type} after canonizaation + # Return a dict of form {name: type} after normalization return {row[0]: self._parse_type(*row[1:]) for row in rows} # @lru_cache() @@ -339,7 +343,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: return self.default_schema, path[0] elif len(path) != 2: raise ValueError( - f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" ) return path @@ -407,6 +411,7 @@ class Postgres(ThreadedDatabase): "decimal": Decimal, "integer": Integer, "numeric": Decimal, + "bigint": Integer, } ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 565d51a2..2c9a538c 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -12,7 +12,7 @@ from runtype import dataclass from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database, NumericType, PrecisionType, ColType +from .database import Database, NumericType, PrecisionType, ColType, UnknownColType logger = logging.getLogger("diff_tables") @@ -142,7 +142,8 @@ def with_schema(self) -> "TableSegment": "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema." if self._schema: return self - schema = self.database.query_table_schema(self.table_path) + + schema = self.database.query_table_schema(self.table_path, self._relevant_columns) if self.case_sensitive: schema = Schema_CaseSensitive(schema) else: @@ -381,6 +382,13 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c] = col1.replace(precision=lowest.precision) table2._schema[c] = col2.replace(precision=lowest.precision) + for t in [table1, table2]: + for c in t._relevant_columns: + ctype = t._schema[c] + if isinstance(ctype, UnknownColType): + logger.warn(f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives.") + def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded From 6df9d37c7e80499fbb6aada1c9c7344f13f7d993 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 11:09:53 +0200 Subject: [PATCH 2/3] A better error message --- data_diff/database.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_diff/database.py b/data_diff/database.py index 79117f0e..1c03d25c 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -301,6 +301,8 @@ def _parse_type( return cls(precision=0) elif issubclass(cls, Decimal): + if numeric_scale is None: + raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column of type {type_repr}.") return cls(precision=numeric_scale) assert issubclass(cls, Float) From 75e3e528e4f1423f06d697385a766e17af2b02b2 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 22 Jun 2022 12:04:58 +0200 Subject: [PATCH 3/3] Fixes for PR --- data_diff/database.py | 51 ++++++++++++++++++++++++--------------- data_diff/diff_tables.py | 6 +++-- tests/test_diff_tables.py | 5 ++-- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/data_diff/database.py b/data_diff/database.py index 1c03d25c..fffbdbb8 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -173,7 +173,6 @@ def close(self): "Close connection(s) to the database instance. Querying will stop functioning." ... - @abstractmethod def normalize_timestamp(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -282,7 +281,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int: return math.floor(math.log(2**p, 10)) def _parse_type( - self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: """ """ @@ -302,7 +306,7 @@ def _parse_type( elif issubclass(cls, Decimal): if numeric_scale is None: - raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column of type {type_repr}.") + raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.") return cls(precision=numeric_scale) assert issubclass(cls, Float) @@ -333,7 +337,7 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str rows = [r for r in rows if r[0].lower() in accept] # Return a dict of form {name: type} after normalization - return {row[0]: self._parse_type(*row[1:]) for row in rows} + return {row[0]: self._parse_type(*row) for row in rows} # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: @@ -344,13 +348,10 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: if self.default_schema: return self.default_schema, path[0] elif len(path) != 2: - raise ValueError( - f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" - ) + raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") return path - def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) @@ -446,13 +447,14 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"{s}::varchar" - def normalize_timestamp(self, value: str, coltype: ColType) -> str: if coltype.rounds: return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"{value}::decimal(38, {coltype.precision})") @@ -502,9 +504,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: else: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") @@ -517,7 +517,9 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None + ) -> ColType: regexps = { r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, @@ -607,7 +609,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - class Oracle(ThreadedDatabase): ROUNDS_ON_PREC_LOSS = True @@ -661,7 +662,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str: return f"to_char({value}, '{format_str}')" def _parse_type( - self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: """ """ regexps = { @@ -720,15 +726,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: us = f"extract(us from {timestamp})" # epoch = Total time since epoch in microseconds. epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + timestamp6 = ( + f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + ) else: timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: return self.to_string(f"{value}::decimal(38,{coltype.precision})") - def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -838,7 +847,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str: return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) def normalize_number(self, value: str, coltype: ColType) -> str: if isinstance(coltype, Integer): diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 2c9a538c..7df04321 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -386,8 +386,10 @@ def _validate_and_adjust_columns(self, table1, table2): for c in t._relevant_columns: ctype = t._schema[c] if isinstance(ctype, UnknownColType): - logger.warn(f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives.") + logger.warn( + f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 84bb6b9a..3649b37f 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -32,17 +32,16 @@ def tearDownClass(cls): cls.preql.close() cls.connection.close() - # Fallback for test runners that doesn't support setUpClass/tearDownClass def setUp(self) -> None: - if not hasattr(self, 'connection'): + if not hasattr(self, "connection"): self.setUpClass.__func__(self) self.private_connection = True return super().setUp() def tearDown(self) -> None: - if hasattr(self, 'private_connection'): + if hasattr(self, "private_connection"): self.tearDownClass.__func__(self) return super().tearDown()