diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 1514f9e3..e57d46a8 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -492,6 +492,26 @@ def post_create_table(self, table): return post_cmds + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + text = super().visit_create_index( + create, include_schema, include_table_schema, **kw + ) + index = create.element + if "spanner" in index.dialect_options: + options = index.dialect_options["spanner"] + if "storing" in options: + storing = options["storing"] + storing_columns = [ + index.table.c[col] if isinstance(col, str) else col + for col in storing + ] + text += " STORING (%s)" % ", ".join( + [self.preparer.quote(c.name) for c in storing_columns] + ) + return text + def get_identity_options(self, identity_options): text = ["sequence_kind = 'bit_reversed_positive'"] if identity_options.start is not None: @@ -997,15 +1017,35 @@ def get_multi_indexes( i.table_schema, i.table_name, i.index_name, - ARRAY_AGG(ic.column_name), + ( + SELECT ARRAY_AGG(ic.column_name) + FROM information_schema.index_columns ic + WHERE ic.index_name = i.index_name + AND ic.table_catalog = i.table_catalog + AND ic.table_schema = i.table_schema + AND ic.table_name = i.table_name + AND ic.column_ordering is not null + ) as columns, i.is_unique, - ARRAY_AGG(ic.column_ordering) + ( + SELECT ARRAY_AGG(ic.column_ordering) + FROM information_schema.index_columns ic + WHERE ic.index_name = i.index_name + AND ic.table_catalog = i.table_catalog + AND ic.table_schema = i.table_schema + AND ic.table_name = i.table_name + AND ic.column_ordering is not null + ) as column_orderings, + ( + SELECT ARRAY_AGG(storing.column_name) + FROM information_schema.index_columns storing + WHERE storing.index_name = i.index_name + AND storing.table_catalog = i.table_catalog + AND storing.table_schema = i.table_schema + AND storing.table_name = i.table_name + AND storing.column_ordering is null + ) as storing_columns, FROM information_schema.indexes as i - JOIN information_schema.index_columns AS ic - ON ic.index_name = i.index_name - AND ic.table_catalog = i.table_catalog - AND ic.table_schema = i.table_schema - AND ic.table_name = i.table_name JOIN information_schema.tables AS t ON i.table_catalog = t.table_catalog AND i.table_schema = t.table_schema @@ -1016,7 +1056,8 @@ def get_multi_indexes( {schema_filter_query} i.index_type != 'PRIMARY_KEY' AND i.spanner_is_managed = FALSE - GROUP BY i.table_schema, i.table_name, i.index_name, i.is_unique + GROUP BY i.table_catalog, i.table_schema, i.table_name, + i.index_name, i.is_unique ORDER BY i.index_name """.format( table_filter_query=table_filter_query, @@ -1029,13 +1070,19 @@ def get_multi_indexes( result_dict = {} for row in rows: + dialect_options = {} + include_columns = row[6] + if include_columns: + dialect_options["spanner_storing"] = include_columns index_info = { "name": row[2], "column_names": row[3], "unique": row[4], "column_sorting": { - col: order for col, order in zip(row[3], row[5]) + col: order.lower() for col, order in zip(row[3], row[5]) }, + "include_columns": include_columns if include_columns else [], + "dialect_options": dialect_options, } row[0] = row[0] or None table_info = result_dict.get((row[0], row[1]), []) diff --git a/test/system/test_basics.py b/test/system/test_basics.py index 6a53b234..3357104c 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -11,7 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import text, Table, Column, Integer, PrimaryKeyConstraint, String + +from sqlalchemy import ( + text, + Table, + Column, + Integer, + PrimaryKeyConstraint, + String, + Index, + MetaData, + Boolean, +) from sqlalchemy.testing import eq_ from sqlalchemy.testing.plugin.plugin_base import fixtures @@ -19,13 +30,21 @@ class TestBasics(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table( + numbers = Table( "numbers", metadata, Column("number", Integer), Column("name", String(20)), + Column("alternative_name", String(20)), + Column("prime", Boolean), PrimaryKeyConstraint("number"), ) + Index( + "idx_numbers_name", + numbers.c.name, + numbers.c.prime.desc(), + spanner_storing=[numbers.c.alternative_name], + ) def test_hello_world(self, connection): greeting = connection.execute(text("select 'Hello World'")) @@ -33,7 +52,25 @@ def test_hello_world(self, connection): def test_insert_number(self, connection): connection.execute( - text("insert or update into numbers(number, name) values (1, 'One')") + text( + """insert or update into numbers (number, name, prime) + values (1, 'One', false)""" + ) ) name = connection.execute(text("select name from numbers where number=1")) eq_("One", name.fetchone()[0]) + + def test_reflect(self, connection): + engine = connection.engine + meta: MetaData = MetaData() + meta.reflect(bind=engine) + eq_(1, len(meta.tables)) + table = meta.tables["numbers"] + eq_(1, len(table.indexes)) + index = next(iter(table.indexes)) + eq_(2, len(index.columns)) + eq_("name", index.columns[0].name) + eq_("prime", index.columns[1].name) + dialect_options = index.dialect_options["spanner"] + eq_(1, len(dialect_options["storing"])) + eq_("alternative_name", dialect_options["storing"][0]) diff --git a/test/test_suite_20.py b/test/test_suite_20.py index 4902b3ab..22b23e0a 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -1414,7 +1414,7 @@ def idx( "include_columns": [], } if column_sorting: - res["column_sorting"] = {"q": "DESC"} + res["column_sorting"] = {"q": "desc"} if duplicates: res["duplicates_constraint"] = name return [res] @@ -1458,11 +1458,11 @@ def idx( *idx( "q", name="noncol_idx_nopk", - column_sorting={"q": "DESC"}, + column_sorting={"q": "desc"}, ) ], (schema, "noncol_idx_test_pk"): [ - *idx("q", name="noncol_idx_pk", column_sorting={"q": "DESC"}) + *idx("q", name="noncol_idx_pk", column_sorting={"q": "desc"}) ], (schema, self.temp_table_name()): [ *idx("foo", name="user_tmp_ix"),