5
5
from abc import ABC , abstractmethod
6
6
from runtype import dataclass
7
7
import logging
8
- from typing import Tuple , Optional , List
8
+ from typing import Sequence , Tuple , Optional , List
9
9
from concurrent .futures import ThreadPoolExecutor
10
10
import threading
11
11
from typing import Dict
@@ -159,10 +159,6 @@ def __post_init__(self):
159
159
class UnknownColType (ColType ):
160
160
text : str
161
161
162
- def __post_init__ (self ):
163
- logger .warn (f"Column of type '{ self .text } ' has no compatibility handling. "
164
- "If encoding/formatting differs between databases, it may result in false positives." )
165
-
166
162
167
163
class AbstractDatabase (ABC ):
168
164
@abstractmethod
@@ -191,7 +187,7 @@ def select_table_schema(self, path: DbPath) -> str:
191
187
...
192
188
193
189
@abstractmethod
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
    """Query the table for its schema for table in 'path', and return {column: type}

    Parameters:
        path: Path of the table whose schema is requested.
        filter_columns: Optional sequence of column names. When given,
            implementations are expected to restrict the result to these
            columns; None (the default) means "all columns".

    Returns:
        A dict mapping each column name to its parsed column type.
    """
    ...
197
193
@@ -205,7 +201,6 @@ def close(self):
205
201
"Close connection(s) to the database instance. Querying will stop functioning."
206
202
...
207
203
208
-
209
204
@abstractmethod
210
205
def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
211
206
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -269,6 +264,10 @@ class Database(AbstractDatabase):
269
264
DATETIME_TYPES = {}
270
265
default_schema = None
271
266
267
@property
def name(self) -> str:
    """Name of the concrete database class; used as a prefix in error messages."""
    return type(self).__name__
270
+
272
271
def query (self , sql_ast : SqlOrStr , res_type : type ):
273
272
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
274
273
@@ -310,7 +309,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
310
309
return math .floor (math .log (2 ** p , 10 ))
311
310
312
311
def _parse_type (
313
- self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None , numeric_scale : int = None
312
+ self ,
313
+ col_name : str ,
314
+ type_repr : str ,
315
+ datetime_precision : int = None ,
316
+ numeric_precision : int = None ,
317
+ numeric_scale : int = None ,
314
318
) -> ColType :
315
319
""" """
316
320
@@ -329,6 +333,8 @@ def _parse_type(
329
333
return cls (precision = 0 )
330
334
331
335
elif issubclass (cls , Decimal ):
336
+ if numeric_scale is None :
337
+ raise ValueError (f"{ self .name } : Unexpected numeric_scale is NULL, for column { col_name } of type { type_repr } ." )
332
338
return cls (precision = numeric_scale )
333
339
334
340
assert issubclass (cls , Float )
@@ -349,13 +355,17 @@ def select_table_schema(self, path: DbPath) -> str:
349
355
f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
350
356
)
351
357
352
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
    """Query the schema of the table at 'path' and return it as {column: type}.

    Parameters:
        path: Path of the table; passed to select_table_schema() to build the query.
        filter_columns: Optional sequence of column names to keep. Matching is
            case-insensitive. None (the default) keeps every column.

    Returns:
        A dict mapping each column name (as reported by the database) to the
        ColType produced by _parse_type() for that row.

    Raises:
        RuntimeError: If the schema query returns no rows, i.e. the table
            does not exist or has no columns.
    """
    rows = self.query(self.select_table_schema(path), list)
    if not rows:
        raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

    if filter_columns is not None:
        # Compare lower-cased names so the filter ignores case differences
        # between the requested names and what the database reports.
        accept = {i.lower() for i in filter_columns}
        rows = [r for r in rows if r[0].lower() in accept]

    # Return a dict of form {name: type} after normalization.
    # The whole row (column name included) is forwarded, since _parse_type
    # takes the column name as its first argument.
    return {row[0]: self._parse_type(*row) for row in rows}
359
369
360
370
# @lru_cache()
361
371
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -366,9 +376,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
366
376
if self .default_schema :
367
377
return self .default_schema , path [0 ]
368
378
elif len (path ) != 2 :
369
- raise ValueError (
370
- f"{ self .__class__ .__name__ } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table"
371
- )
379
+ raise ValueError (f"{ self .name } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table" )
372
380
373
381
return path
374
382
@@ -440,6 +448,7 @@ class PostgreSQL(ThreadedDatabase):
440
448
"decimal" : Decimal ,
441
449
"integer" : Integer ,
442
450
"numeric" : Decimal ,
451
+ "bigint" : Integer ,
443
452
}
444
453
ROUNDS_ON_PREC_LOSS = True
445
454
@@ -472,13 +481,14 @@ def md5_to_int(self, s: str) -> str:
472
481
def to_string (self , s : str ):
473
482
return f"{ s } ::varchar"
474
483
475
-
476
484
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build an SQL expression that renders 'value' as a normalized timestamp string.

    Parameters:
        value: SQL expression (e.g. a column name) to normalize.
        coltype: Column type; its 'rounds' and 'precision' attributes are read.

    Returns:
        SQL text. When coltype.rounds is true, the value is cast to
        timestamp(precision) and formatted with microseconds. Otherwise it is
        cast to timestamp(6), formatted, cut to TIMESTAMP_PRECISION_POS +
        precision characters and right-padded with '0' back to
        TIMESTAMP_PRECISION_POS + 6 characters.
    """
    if coltype.rounds:
        return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"

    # NOTE(review): TIMESTAMP_PRECISION_POS is defined elsewhere in the module;
    # presumably the length of the formatted timestamp up to the fractional
    # digits -- confirm against its definition.
    timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
    return (
        f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS + coltype.precision}), {TIMESTAMP_PRECISION_POS + 6}, '0')"
    )
482
492
483
493
def normalize_number (self , value : str , coltype : ColType ) -> str :
484
494
return self .to_string (f"{ value } ::decimal(38, { coltype .precision } )" )
@@ -528,9 +538,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
528
538
else :
529
539
s = f"date_format(cast({ value } as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
530
540
531
- return (
532
- f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
533
- )
541
+ return f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
534
542
535
543
def normalize_number (self , value : str , coltype : ColType ) -> str :
536
544
return self .to_string (f"cast({ value } as decimal(38,{ coltype .precision } ))" )
@@ -543,7 +551,9 @@ def select_table_schema(self, path: DbPath) -> str:
543
551
f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
544
552
)
545
553
546
- def _parse_type (self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None ) -> ColType :
554
+ def _parse_type (
555
+ self , col_name : str , type_repr : str , datetime_precision : int = None , numeric_precision : int = None
556
+ ) -> ColType :
547
557
regexps = {
548
558
r"timestamp\((\d)\)" : Timestamp ,
549
559
r"timestamp\((\d)\) with time zone" : TimestampTZ ,
@@ -633,7 +643,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
633
643
return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
634
644
635
645
636
-
637
646
class Oracle (ThreadedDatabase ):
638
647
ROUNDS_ON_PREC_LOSS = True
639
648
@@ -687,7 +696,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
687
696
return f"to_char({ value } , '{ format_str } ')"
688
697
689
698
def _parse_type (
690
- self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None , numeric_scale : int = None
699
+ self ,
700
+ col_name : str ,
701
+ type_repr : str ,
702
+ datetime_precision : int = None ,
703
+ numeric_precision : int = None ,
704
+ numeric_scale : int = None ,
691
705
) -> ColType :
692
706
""" """
693
707
regexps = {
@@ -746,15 +760,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
746
760
us = f"extract(us from { timestamp } )"
747
761
# epoch = Total time since epoch in microseconds.
748
762
epoch = f"{ secs } *1000000 + { ms } *1000 + { us } "
749
- timestamp6 = f"to_char({ epoch } , -6+{ coltype .precision } ) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
763
+ timestamp6 = (
764
+ f"to_char({ epoch } , -6+{ coltype .precision } ) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
765
+ )
750
766
else :
751
767
timestamp6 = f"to_char({ value } ::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
752
- return f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
768
+ return (
769
+ f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
770
+ )
753
771
754
772
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build an SQL expression that renders the number 'value' as a normalized string.

    The value is cast to decimal(38, coltype.precision) and the resulting
    expression is passed through self.to_string() (defined elsewhere;
    presumably a cast to text -- confirm).
    """
    return self.to_string(f"{value}::decimal(38,{coltype.precision})")
756
774
757
-
758
775
def select_table_schema (self , path : DbPath ) -> str :
759
776
schema , table = self ._normalize_table_path (path )
760
777
@@ -864,7 +881,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
864
881
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', { value } )"
865
882
866
883
timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', { value } )"
867
- return f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
884
+ return (
885
+ f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
886
+ )
868
887
869
888
def normalize_number (self , value : str , coltype : ColType ) -> str :
870
889
if isinstance (coltype , Integer ):
0 commit comments