diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index 2fcdd59137..785b051996 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -55,11 +55,20 @@ } -def _execute_insert_heterogenous(transaction, sql_params_list): +def _execute_insert_heterogenous( + transaction, sql_params_list, param_types=None +): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) - param_types = get_param_types(params) - transaction.execute_update(sql, params=params, param_types=param_types) + + if param_types is None: + new_param_types = get_param_types(params) + else: + new_param_types = dict(zip(params.keys(), param_types)) + + transaction.execute_update( + sql, params=params, param_types=new_param_types + ) def _execute_insert_homogenous(transaction, parts): @@ -70,7 +79,7 @@ def _execute_insert_homogenous(transaction, parts): return transaction.insert(table, columns, values) -def handle_insert(connection, sql, params): +def handle_insert(connection, sql, params, param_types=None): parts = parse_insert(sql, params) # The split between the two styles exists because: @@ -89,13 +98,15 @@ def handle_insert(connection, sql, params): if parts.get("homogenous"): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. 
- return connection.database.run_in_transaction(_execute_insert_homogenous, parts) + return connection.database.run_in_transaction( + _execute_insert_homogenous, parts + ) else: # All the other cases that are esoteric and need # transaction.execute_sql sql_params_list = parts.get("sql_params_list") return connection.database.run_in_transaction( - _execute_insert_heterogenous, sql_params_list + _execute_insert_heterogenous, sql_params_list, param_types ) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 363c2c653c..aba5eaa916 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -175,18 +175,40 @@ def execute(self, sql, args=None): sql, params = sql_pyformat_args_to_spanner(sql, args) statement = Statement( - sql, params, get_param_types(params), ResultsChecksum(), - ) - (self._result_set, self._checksum,) = self.connection.run_statement( - statement + sql, + params, + get_param_types(params), + ResultsChecksum(), ) + ( + self._result_set, + self._checksum, + ) = self.connection.run_statement(statement) self._itr = PeekIterator(self._result_set) return if classification == parse_utils.STMT_NON_UPDATING: self._handle_DQL(sql, args or None) elif classification == parse_utils.STMT_INSERT: - _helpers.handle_insert(self.connection, sql, args or None) + + # Read INFORMATION_SCHEMA.COLUMNS to get the Spanner types for + # the columns included in this insert + table_name, columns = parse_utils.get_table_cols_for_insert( + sql + ) + schema = self.get_table_column_schema(table_name) + column_type_names = [ + schema[col].spanner_type.split("(")[0] for col in columns + ] + param_types = [ + parse_utils.COL_TYPE_NAME_TO_TYPE[ctn] + for ctn in column_type_names + ] + + _helpers.handle_insert( + self.connection, sql, args or None, param_types=param_types + ) + else: self.connection.database.run_in_transaction( self._do_execute_update, sql, args or None @@ -353,7 +375,9 @@ def 
run_sql_in_snapshot(self, sql, params=None, param_types=None): self.connection.run_prior_DDL_statements() with self.connection.database.snapshot() as snapshot: - res = snapshot.execute_sql(sql, params=params, param_types=param_types) + res = snapshot.execute_sql( + sql, params=params, param_types=param_types + ) return list(res) def get_table_column_schema(self, table_name): diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 8848233d45..0aeb373849 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -27,6 +27,8 @@ from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload +# Map native to Spanner column types, for inferring types from SQL statements +# TODO (#566): Remove TYPES_MAP = { bool: spanner.param_types.BOOL, bytes: spanner.param_types.BYTES, @@ -39,6 +41,18 @@ TimestampStr: spanner.param_types.TIMESTAMP, } +# Map Spanner column type names to the actual types +COL_TYPE_NAME_TO_TYPE = { + "BOOL": spanner.param_types.BOOL, + "BYTES": spanner.param_types.BYTES, + "DATE": spanner.param_types.DATE, + "FLOAT64": spanner.param_types.FLOAT64, + "INT64": spanner.param_types.INT64, + "NUMERIC": spanner.param_types.NUMERIC, + "STRING": spanner.param_types.STRING, + "TIMESTAMP": spanner.param_types.TIMESTAMP, +} + SPANNER_RESERVED_KEYWORDS = { "ALL", "AND", @@ -338,6 +352,21 @@ def parse_insert(insert_sql, params): return {"sql_params_list": sql_param_tuples} +def get_table_cols_for_insert(insert_sql): + """Get table and column names from `insert_sql`. + + :type insert_sql: str + :param insert_sql: A SQL INSERT statement + + :rtype: tuple[str, list[str]] + :returns: The table name and list of column names in the statement. 
+ """ + gd = RE_INSERT.match(insert_sql).groupdict() + table_name = gd.get("table_name", "") + columns = [cn.strip() for cn in gd.get("columns", "").split(",")] + return table_name, columns + + def rows_for_insert_or_update(columns, params, pyformat_args=None): """ Create a tupled list of params to be used as a single value per @@ -507,17 +536,10 @@ def get_param_types(params): :rtype: :class:`dict` :returns: The types index for the given parameters. """ - if params is None: - return - - param_types = {} - - for key, value in params.items(): - type_ = type(value) - if type_ in TYPES_MAP: - param_types[key] = TYPES_MAP[type_] - - return param_types + if params is not None: + return { + key: TYPES_MAP.get(type(value)) for key, value in params.items() + } def ensure_where_clause(sql):