diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index f7a6558a..520b0691 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -176,6 +176,9 @@ def visit_BOOLEAN(self, type_, **kw): def visit_DATETIME(self, type_, **kw): return "TIMESTAMP" + def visit_BIGINT(self, type_, **kw): + return "INT64" + class SpannerDialect(DefaultDialect): """Cloud Spanner dialect. diff --git a/test/test_suite.py b/test/test_suite.py index d0c5313c..bebfd01b 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -16,18 +16,20 @@ import pytest -from sqlalchemy.testing import config +from sqlalchemy.testing import config, db from sqlalchemy.testing import eq_ from sqlalchemy.testing import provide_metadata from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy import literal_column -from sqlalchemy import select, case, bindparam + +from sqlalchemy import bindparam, case, literal, select, util from sqlalchemy import exists from sqlalchemy import Boolean from sqlalchemy import String -from sqlalchemy.testing import requires from sqlalchemy.types import Integer +from sqlalchemy.testing import requires + from google.api_core.datetime_helpers import DatetimeWithNanoseconds from sqlalchemy.testing.suite.test_ddl import * # noqa: F401, F403 @@ -44,9 +46,7 @@ from sqlalchemy.testing.suite.test_dialect import EscapingTest as _EscapingTest from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest from sqlalchemy.testing.suite.test_types import BooleanTest as _BooleanTest - -config.test_schema = "" - +from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403 DateTest as _DateTest, @@ -59,6 +59,8 @@ TimestampMicrosecondsTest, ) +config.test_schema = "" + class EscapingTest(_EscapingTest): @provide_metadata @@ -422,3 +424,88 @@ class TimeTests(_TimeMicrosecondsTest, _TimeTest): @pytest.mark.skip("Spanner doesn't coerce dates from datetime.") class DateTimeCoercedToDateTimeTest(_DateTimeCoercedToDateTimeTest): pass + + +class IntegerTest(_IntegerTest): + @provide_metadata + def _round_trip(self, datatype, data): + """ + SPANNER OVERRIDE: + + This is the helper method for integer class tests which creates a table and + performs an insert operation. + Cloud Spanner supports tables with an empty primary key, but only one + row can be inserted into such a table - following insertions will fail with + `400 id must not be NULL in table date_table`. + Overriding the tests and adding a manual primary key value to avoid the same + failures and deleting the table at the end. + """ + metadata = self.metadata + int_table = Table( + "integer_table", + metadata, + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + config.db.execute(int_table.insert(), {"id": 1, "integer_data": data}) + + row = config.db.execute(select([int_table.c.integer_data])).first() + + eq_(row, (data,)) + + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + config.db.execute(int_table.delete()) + + @provide_metadata + def _literal_round_trip(self, type_, input_, output, filter_=None): + """ + SPANNER OVERRIDE: + + Spanner DBAPI does not execute DDL statements unless followed by a + non DDL statement, which is preventing correct table clean up. + The table already exists after related tests finish, so it doesn't + create a new table and when running tests for other data types + insertions will fail with `400 Duplicate name in schema: t`. + Overriding the tests to create and drop a new table to prevent + database existence errors. + """ + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + t = Table("int_t", self.metadata, Column("x", type_)) + t.create() + + with db.connect() as conn: + for value in input_: + ins = ( + t.insert() + .values(x=literal(value)) + .compile( + dialect=db.dialect, compile_kwargs=dict(literal_binds=True), + ) + ) + conn.execute(ins) + conn.execute("SELECT 1") + + if self.supports_whereclause: + stmt = t.select().where(t.c.x == literal(value)) + else: + stmt = t.select() + + stmt = stmt.compile( + dialect=db.dialect, compile_kwargs=dict(literal_binds=True), + ) + for row in conn.execute(stmt): + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output