|
20 | 20 | import warnings
|
21 | 21 | import pytest
|
22 | 22 |
|
| 23 | +from google.cloud.spanner_admin_database_v1 import DatabaseDialect |
23 | 24 | from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
|
24 | 25 | from google.cloud.spanner_dbapi.exceptions import (
|
25 | 26 | InterfaceError,
|
@@ -58,14 +59,16 @@ def _get_client_info(self):
|
58 | 59 |
|
59 | 60 | return ClientInfo(user_agent=USER_AGENT)
|
60 | 61 |
|
61 |
| - def _make_connection(self, **kwargs): |
| 62 | + def _make_connection( |
| 63 | + self, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, **kwargs |
| 64 | + ): |
62 | 65 | from google.cloud.spanner_v1.instance import Instance
|
63 | 66 | from google.cloud.spanner_v1.client import Client
|
64 | 67 |
|
65 | 68 | # We don't need a real Client object to test the constructor
|
66 | 69 | client = Client()
|
67 | 70 | instance = Instance(INSTANCE, client=client)
|
68 |
| - database = instance.database(DATABASE) |
| 71 | + database = instance.database(DATABASE, database_dialect=database_dialect) |
69 | 72 | return Connection(instance, database, **kwargs)
|
70 | 73 |
|
71 | 74 | @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit")
|
@@ -105,6 +108,22 @@ def test_property_instance(self):
|
105 | 108 | self.assertIsInstance(connection.instance, Instance)
|
106 | 109 | self.assertEqual(connection.instance, connection._instance)
|
107 | 110 |
|
| 111 | + def test_property_current_schema_google_sql_dialect(self): |
| 112 | + from google.cloud.spanner_v1.database import Database |
| 113 | + |
| 114 | + connection = self._make_connection( |
| 115 | + database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL |
| 116 | + ) |
| 117 | + self.assertIsInstance(connection.database, Database) |
| 118 | + self.assertEqual(connection.current_schema, "") |
| 119 | + |
| 120 | + def test_property_current_schema_postgres_sql_dialect(self): |
| 121 | + from google.cloud.spanner_v1.database import Database |
| 122 | + |
| 123 | + connection = self._make_connection(database_dialect=DatabaseDialect.POSTGRESQL) |
| 124 | + self.assertIsInstance(connection.database, Database) |
| 125 | + self.assertEqual(connection.current_schema, "public") |
| 126 | + |
108 | 127 | def test_read_only_connection(self):
|
109 | 128 | connection = self._make_connection(read_only=True)
|
110 | 129 | self.assertTrue(connection.read_only)
|
@@ -745,11 +764,22 @@ def __init__(self, name="instance_id", client=None):
|
745 | 764 | self.name = name
|
746 | 765 | self._client = client
|
747 | 766 |
|
748 |
| - def database(self, database_id="database_id", pool=None): |
749 |
| - return _Database(database_id, pool) |
| 767 | + def database( |
| 768 | + self, |
| 769 | + database_id="database_id", |
| 770 | + pool=None, |
| 771 | + database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, |
| 772 | + ): |
| 773 | + return _Database(database_id, pool, database_dialect) |
750 | 774 |
|
751 | 775 |
|
752 | 776 | class _Database(object):
|
753 |
| - def __init__(self, database_id="database_id", pool=None): |
| 777 | + def __init__( |
| 778 | + self, |
| 779 | + database_id="database_id", |
| 780 | + pool=None, |
| 781 | + database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, |
| 782 | + ): |
754 | 783 | self.name = database_id
|
755 | 784 | self.pool = pool
|
| 785 | + self.database_dialect = database_dialect |
0 commit comments