diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 54282876..3454c95f 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -137,6 +137,32 @@ def visit_empty_set_expr(self, type_): _type_map_inv[type(type_[0])] ) + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind parameters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + Cloud spanner supports prefixed backslash to escape non-alphanumeric characters + in string. Override the method to add additional escape before using it to + generate a SQL statement. + """ + raw = ["\\", "'", '"', "\n", "\t", "\r"] + if type(value) == str and any(single in value for single in raw): + value = 'r"""{}"""'.format(value) + return value + else: + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) + else: + raise NotImplementedError( + "Don't know how to literal-quote value %r" % value + ) + class SpannerDDLCompiler(DDLCompiler): """Spanner DDL statements compiler.""" @@ -183,7 +209,7 @@ def visit_primary_key_constraint(self, constraint): return None def visit_unique_constraint(self, constraint): - """Unique contraints in Spanner are defined with indexes: + """Unique constraints in Spanner are defined with indexes: https://cloud.google.com/spanner/docs/secondary-indexes#unique-indexes The method throws an exception to notify user that in diff --git a/test/test_suite.py b/test/test_suite.py index 3e2d5330..b0d48f11 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -86,6 +86,8 @@ from sqlalchemy.testing.suite.test_select import OrderByLabelTest as _OrderByLabelTest from sqlalchemy.testing.suite.test_types import BooleanTest as _BooleanTest from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest +from sqlalchemy.testing.suite.test_types import StringTest as _StringTest + from sqlalchemy.testing.suite.test_types import _LiteralRoundTripFixture from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403 @@ -535,7 +537,6 @@ def _literal_round_trip(self, type_, input_, output, filter_=None): ) ) conn.execute(ins) - conn.execute("SELECT 1") if self.supports_whereclause: stmt = t.select().where(t.c.x == literal(value)) @@ -843,3 +844,9 @@ def test_nolength_binary(self): @pytest.mark.skip("Spanner doesn't support quotes in table names.") class QuotedNameArgumentTest(_QuotedNameArgumentTest): pass + + +class StringTest(_StringTest): + @pytest.mark.skip("Spanner doesn't support non-ascii characters") + def test_literal_non_ascii(self): + pass