# Patch summary (text before the first "diff --git" header is ignored by patch tools).
#
# data_diff/databases/base.py
#   query_table_schema: whitespace only; two blank lines are inserted, one before
#   the empty-result check and one before the length assertion.
#
# data_diff/databases/duckdb.py
#   select_table_schema: the generated query gains a `table_catalog` filter. When
#   the normalized path carries a database, the filter compares against that name
#   as a quoted literal; otherwise it compares against `current_catalog()`.
#
# tests/test_duckdb.py (new file)
#   unittest cases for `_normalize_table_path` (1-, 2- and 3-part paths, plus a
#   4-part path that must raise ValueError) and for the exact SQL text returned
#   by `select_table_schema`.
#
# NOTE(review): the leading indentation of every line below was reconstructed from
# Python block structure; verify it against the target files before applying.
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 8adea0b56..56e64076d 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -1046,6 +1046,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
         accessing the schema using a SQL query.
         """
         rows = self.query(self.select_table_schema(path), list, log_message=path)
+
         if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
 
@@ -1060,6 +1061,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
             )
             for r in rows
         }
+
         assert len(d) == len(rows)
         return d
 
diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py
index 4cdacde87..9537ce50e 100644
--- a/data_diff/databases/duckdb.py
+++ b/data_diff/databases/duckdb.py
@@ -167,12 +167,16 @@ def select_table_schema(self, path: DbPath) -> str:
         database, schema, table = self._normalize_table_path(path)
 
         info_schema_path = ["information_schema", "columns"]
+
         if database:
             info_schema_path.insert(0, database)
+            dynamic_database_clause = f"'{database}'"
+        else:
+            dynamic_database_clause = "current_catalog()"
 
         return (
             f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
-            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+            f"WHERE table_name = '{table}' AND table_schema = '{schema}' and table_catalog = {dynamic_database_clause}"
         )
 
     def _normalize_table_path(self, path: DbPath) -> DbPath:
diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py
new file mode 100644
index 000000000..10c3644e3
--- /dev/null
+++ b/tests/test_duckdb.py
@@ -0,0 +1,42 @@
+import unittest
+from data_diff.databases import duckdb as duckdb_differ
+import os
+import uuid
+
+test_duckdb_filepath = str(uuid.uuid4()) + ".duckdb"
+
+
+class TestDuckDBTableSchemaMethods(unittest.TestCase):
+    def setUp(self):
+        # Create a new duckdb file
+        self.duckdb_conn = duckdb_differ.DuckDB(filepath=test_duckdb_filepath)
+
+    def tearDown(self):
+        # Optional: delete file after tests
+        os.remove(test_duckdb_filepath)
+
+    def test_normalize_table_path(self):
+        self.assertEqual(self.duckdb_conn._normalize_table_path(("test_table",)), (None, "main", "test_table"))
+        self.assertEqual(
+            self.duckdb_conn._normalize_table_path(("test_schema", "test_table")), (None, "test_schema", "test_table")
+        )
+        self.assertEqual(
+            self.duckdb_conn._normalize_table_path(("test_database", "test_schema", "test_table")),
+            ("test_database", "test_schema", "test_table"),
+        )
+
+        with self.assertRaises(ValueError):
+            self.duckdb_conn._normalize_table_path(("test_database", "test_schema", "test_table", "extra"))
+
+    def test_select_table_schema(self):
+        db_path = ("test_table",)
+        expected_sql = "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_name = 'test_table' AND table_schema = 'main' and table_catalog = current_catalog()"
+        self.assertEqual(self.duckdb_conn.select_table_schema(db_path), expected_sql)
+
+        db_path = ("custom_schema", "test_table")
+        expected_sql = "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_name = 'test_table' AND table_schema = 'custom_schema' and table_catalog = current_catalog()"
+        self.assertEqual(self.duckdb_conn.select_table_schema(db_path), expected_sql)
+
+        db_path = ("custom_db", "custom_schema", "test_table")
+        expected_sql = "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM custom_db.information_schema.columns WHERE table_name = 'test_table' AND table_schema = 'custom_schema' and table_catalog = 'custom_db'"
+        self.assertEqual(self.duckdb_conn.select_table_schema(db_path), expected_sql)