Skip to content

fix: Introspect to get column types in cursor.execute #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,20 @@
}


def _execute_insert_heterogenous(transaction, sql_params_list):
def _execute_insert_heterogenous(
transaction, sql_params_list, param_types=None
):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
param_types = get_param_types(params)
transaction.execute_update(sql, params=params, param_types=param_types)

if param_types is None:
new_param_types = get_param_types(params)
else:
new_param_types = dict(zip(params.keys(), param_types))

transaction.execute_update(
sql, params=params, param_types=new_param_types
)


def _execute_insert_homogenous(transaction, parts):
Expand All @@ -70,7 +79,7 @@ def _execute_insert_homogenous(transaction, parts):
return transaction.insert(table, columns, values)


def handle_insert(connection, sql, params):
def handle_insert(connection, sql, params, param_types=None):
parts = parse_insert(sql, params)

# The split between the two styles exists because:
Expand All @@ -89,13 +98,15 @@ def handle_insert(connection, sql, params):
if parts.get("homogenous"):
# The common case of multiple values being passed in
# non-complex pyformat args and need to be uploaded in one RPC.
return connection.database.run_in_transaction(_execute_insert_homogenous, parts)
return connection.database.run_in_transaction(
_execute_insert_homogenous, parts
)
else:
# All the other cases that are esoteric and need
# transaction.execute_sql
sql_params_list = parts.get("sql_params_list")
return connection.database.run_in_transaction(
_execute_insert_heterogenous, sql_params_list
_execute_insert_heterogenous, sql_params_list, param_types
)


Expand Down
36 changes: 30 additions & 6 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,40 @@ def execute(self, sql, args=None):
sql, params = sql_pyformat_args_to_spanner(sql, args)

statement = Statement(
sql, params, get_param_types(params), ResultsChecksum(),
)
(self._result_set, self._checksum,) = self.connection.run_statement(
statement
sql,
params,
get_param_types(params),
ResultsChecksum(),
)
(
self._result_set,
self._checksum,
) = self.connection.run_statement(statement)
self._itr = PeekIterator(self._result_set)
return

if classification == parse_utils.STMT_NON_UPDATING:
self._handle_DQL(sql, args or None)
elif classification == parse_utils.STMT_INSERT:
_helpers.handle_insert(self.connection, sql, args or None)

# Read INFORMATION_SCHEMA.COLUMNS to get the Spanner types for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would strongly recommend that you only do this as a workaround when running against the emulator, as real Spanner seems to accept the parameter type to be missing. Otherwise you will be slowing down statements unnecessarily in production environments.
Furthermore, you probably need to do this for all types of statements on the emulator, not only insert statements.

# the columns included in this insert
table_name, columns = parse_utils.get_table_cols_for_insert(
sql
)
schema = self.get_table_column_schema(table_name)
column_type_names = [
schema[col].spanner_type.split("(")[0] for col in columns
]
param_types = [
parse_utils.COL_TYPE_NAME_TO_TYPE[ctn]
for ctn in column_type_names
]

_helpers.handle_insert(
self.connection, sql, args or None, param_types=param_types
)

else:
self.connection.database.run_in_transaction(
self._do_execute_update, sql, args or None
Expand Down Expand Up @@ -353,7 +375,9 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):
self.connection.run_prior_DDL_statements()

with self.connection.database.snapshot() as snapshot:
res = snapshot.execute_sql(sql, params=params, param_types=param_types)
res = snapshot.execute_sql(
sql, params=params, param_types=param_types
)
return list(res)

def get_table_column_schema(self, table_name):
Expand Down
44 changes: 33 additions & 11 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload

# Map native to Spanner column types, for inferring types from SQL statements
# TODO (#566): Remove
TYPES_MAP = {
bool: spanner.param_types.BOOL,
bytes: spanner.param_types.BYTES,
Expand All @@ -39,6 +41,18 @@
TimestampStr: spanner.param_types.TIMESTAMP,
}

# Map Spanner column type names to the actual types
COL_TYPE_NAME_TO_TYPE = {
"BOOL": spanner.param_types.BOOL,
"BYTES": spanner.param_types.BYTES,
"DATE": spanner.param_types.DATE,
"FLOAT64": spanner.param_types.FLOAT64,
"INT64": spanner.param_types.INT64,
"NUMERIC": spanner.param_types.NUMERIC,
"STRING": spanner.param_types.STRING,
"TIMESTAMP": spanner.param_types.TIMESTAMP,
}

SPANNER_RESERVED_KEYWORDS = {
"ALL",
"AND",
Expand Down Expand Up @@ -338,6 +352,21 @@ def parse_insert(insert_sql, params):
return {"sql_params_list": sql_param_tuples}


def get_table_cols_for_insert(insert_sql):
"""Get table and column names from `insert_sql`.o

:type insert_sql: str
:param params: A SQL INSERT statement

:rtype: tuple[str, list[str]]
:returns: The table name and list of column names in the statement.
"""
gd = RE_INSERT.match(insert_sql).groupdict()
table_name = gd.get("table_name", "")
columns = [cn.strip() for cn in gd.get("columns", "").split(",")]
return table_name, columns


def rows_for_insert_or_update(columns, params, pyformat_args=None):
"""
Create a tupled list of params to be used as a single value per
Expand Down Expand Up @@ -507,17 +536,10 @@ def get_param_types(params):
:rtype: :class:`dict`
:returns: The types index for the given parameters.
"""
if params is None:
return

param_types = {}

for key, value in params.items():
type_ = type(value)
if type_ in TYPES_MAP:
param_types[key] = TYPES_MAP[type_]

return param_types
if params is not None:
return {
key: TYPES_MAP.get(type(value)) for key, value in params.items()
}


def ensure_where_clause(sql):
Expand Down