Skip to content

chore: Add support for named schema #858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 25, 2024
84 changes: 39 additions & 45 deletions django_spanner/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
TypeCode.NUMERIC: "DecimalField",
TypeCode.JSON: "JSONField",
}
if USE_EMULATOR:
# Emulator does not support table_type yet.
# https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/43
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_name
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
"""
else:
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_type
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
"""
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_type
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = @schema_name
"""

def get_field_type(self, data_type, description):
"""A hook for a Spanner database to use the cursor description to
Expand Down Expand Up @@ -76,7 +64,10 @@ def get_table_list(self, cursor):
:rtype: list
:returns: A list of table and view names in the current database.
"""
results = cursor.run_sql_in_snapshot(self.LIST_TABLE_SQL)
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
self.LIST_TABLE_SQL, params={"schema_name": schema_name}
)
tables = []
# The second TableInfo field is 't' for table or 'v' for view.
for row in results:
Expand Down Expand Up @@ -159,8 +150,9 @@ def get_relations(self, cursor, table_name):
:rtype: dict
:returns: A dictionary representing column relationships to other tables.
"""
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
'''
"""
SELECT
tc.COLUMN_NAME as col, ccu.COLUMN_NAME as ref_col, ccu.TABLE_NAME as ref_table
FROM
Expand All @@ -174,8 +166,9 @@ def get_relations(self, cursor, table_name):
ON
rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="%s"'''
% self.connection.ops.quote_name(table_name)
tc.TABLE_SCHEMA=@schema_name AND tc.TABLE_NAME=@view_name
""",
params={"schema_name": schema_name, "view_name": table_name},
)
return {
column: (referred_column, referred_table)
Expand All @@ -194,6 +187,7 @@ def get_primary_key_column(self, cursor, table_name):
:rtype: str
:returns: The name of the PK column.
"""
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
"""
SELECT
Expand All @@ -205,9 +199,9 @@ def get_primary_key_column(self, cursor, table_name):
AS
ccu ON tc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="%s" AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA=''
"""
% self.connection.ops.quote_name(table_name)
tc.TABLE_NAME=@table_name AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA=@schema_name
""",
params={"schema_name": schema_name, "table_name": table_name},
)
return results[0][0] if results else None

Expand All @@ -224,18 +218,17 @@ def get_constraints(self, cursor, table_name):
:returns: A dictionary with constraints.
"""
constraints = {}
quoted_table_name = self.connection.ops.quote_name(table_name)
schema_name = self._get_schema_name(cursor)

# Firstly populate all available constraints and their columns.
constraint_columns = cursor.run_sql_in_snapshot(
'''
"""
SELECT
CONSTRAINT_NAME, COLUMN_NAME
FROM
INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE
WHERE TABLE_NAME="{table}"'''.format(
table=quoted_table_name
)
WHERE TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""",
params={"table": table_name, "schema_name": schema_name},
)
for constraint, column_name in constraint_columns:
if constraint not in constraints:
Expand All @@ -254,15 +247,14 @@ def get_constraints(self, cursor, table_name):

# Add the various constraints by type.
constraint_types = cursor.run_sql_in_snapshot(
'''
"""
SELECT
CONSTRAINT_NAME, CONSTRAINT_TYPE
FROM
INFORMATION_SCHEMA.TABLE_CONSTRAINTS
WHERE
TABLE_NAME="{table}"'''.format(
table=quoted_table_name
)
TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""",
params={"table": table_name, "schema_name": schema_name},
)
for constraint, constraint_type in constraint_types:
already_added = constraint in constraints
Expand Down Expand Up @@ -303,14 +295,13 @@ def get_constraints(self, cursor, table_name):
RIGHT JOIN
INFORMATION_SCHEMA.INDEX_COLUMNS AS idx_col
ON
idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME="{table}"
idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME=@table
WHERE
idx.TABLE_NAME="{table}"
idx.TABLE_NAME=@table AND idx.TABLE_SCHEMA=@schema_name
ORDER BY
idx_col.ORDINAL_POSITION
""".format(
table=quoted_table_name
)
""",
params={"table": table_name, "schema_name": schema_name},
)
for (
index_name,
Expand Down Expand Up @@ -350,6 +341,7 @@ def get_key_columns(self, cursor, table_name):
for all key columns in the given table.
"""
key_columns = []
schema_name = self._get_schema_name(cursor)
cursor.execute(
"""SELECT
tc.COLUMN_NAME as column_name,
Expand All @@ -366,10 +358,12 @@ def get_key_columns(self, cursor, table_name):
ON
rc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="{table}"
""".format(
table=self.connection.ops.quote_name(table_name)
)
tc.TABLE_NAME=@table AND tc.TABLE_SCHEMA=@schema_name
""",
params={"table": table_name, "schema_name": schema_name},
)
key_columns.extend(cursor.fetchall())
return key_columns

def _get_schema_name(self, cursor):
return cursor.connection.current_schema