Skip to content

Commit c40f02c

Browse files
committed
feat: support AUTO_INCREMENT and IDENTITY columns
Adds support for IDENTITY and AUTO_INCREMENT columns to the Spanner dialect. These are used by default for primary key generation. By default, IDENTITY columns using a backing bit-reversed sequence are used for primary key generation. The sequence kind to use can be configured by setting the attribute default_sequence_kind on the Spanner dialect. The use of AUTO_INCREMENT columns instead of IDENTITY can be configured by setting the use_auto_increment attribute on the Spanner dialect.
1 parent 89c4322 commit c40f02c

File tree

6 files changed

+191
-9
lines changed

6 files changed

+191
-9
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,32 @@ def get_column_specification(self, column, **kwargs):
409409
if not column.nullable:
410410
colspec += " NOT NULL"
411411

412+
has_identity = (
413+
column.identity is not None and self.dialect.supports_identity_columns
414+
)
412415
default = self.get_column_default_string(column)
413-
if default is not None:
414-
colspec += " DEFAULT (" + default + ")"
415416

416-
if hasattr(column, "computed") and column.computed is not None:
417+
if (
418+
column.primary_key
419+
and column is column.table._autoincrement_column
420+
and not has_identity
421+
and default is None
422+
):
423+
if (
424+
hasattr(self.dialect, "use_auto_increment")
425+
and self.dialect.use_auto_increment
426+
):
427+
colspec += " AUTO_INCREMENT"
428+
else:
429+
sequence_kind = getattr(
430+
self.dialect, "default_sequence_kind", "BIT_REVERSED_POSITIVE"
431+
)
432+
colspec += " GENERATED BY DEFAULT AS IDENTITY (%s)" % sequence_kind
433+
elif has_identity:
434+
colspec += " " + self.process(column.identity)
435+
elif default is not None:
436+
colspec += " DEFAULT (" + default + ")"
437+
elif hasattr(column, "computed") and column.computed is not None:
417438
colspec += " " + self.process(column.computed)
418439

419440
return colspec
@@ -526,6 +547,12 @@ def visit_create_index(
526547
return text
527548

528549
def get_identity_options(self, identity_options):
550+
text = ["bit_reversed_positive"]
551+
if identity_options.start is not None:
552+
text.append("start counter with %d" % identity_options.start)
553+
return " ".join(text)
554+
555+
def get_sequence_options(self, identity_options):
529556
text = ["sequence_kind = 'bit_reversed_positive'"]
530557
if identity_options.start is not None:
531558
text.append("start_with_counter = %d" % identity_options.start)
@@ -534,7 +561,7 @@ def get_identity_options(self, identity_options):
534561
def visit_create_sequence(self, create, prefix=None, **kw):
535562
"""Builds a ``CREATE SEQUENCE`` statement for the sequence."""
536563
text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
537-
options = self.get_identity_options(create.element)
564+
options = self.get_sequence_options(create.element)
538565
if options:
539566
text += " OPTIONS (" + options + ")"
540567
return text
@@ -628,11 +655,13 @@ class SpannerDialect(DefaultDialect):
628655
supports_default_values = False
629656
supports_sequences = True
630657
sequences_optional = False
658+
supports_identity_columns = True
631659
supports_native_enum = True
632660
supports_native_boolean = True
633661
supports_native_decimal = True
634662
supports_statement_cache = True
635663

664+
postfetch_lastrowid = False
636665
insert_returning = True
637666
update_returning = True
638667
delete_returning = True
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2025 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import String
16+
from sqlalchemy.orm import DeclarativeBase
17+
from sqlalchemy.orm import Mapped
18+
from sqlalchemy.orm import mapped_column
19+
20+
21+
class Base(DeclarativeBase):
22+
pass
23+
24+
25+
class Singer(Base):
26+
__tablename__ = "singers"
27+
id: Mapped[int] = mapped_column(primary_key=True)
28+
name: Mapped[str] = mapped_column(String)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import create_engine
16+
from sqlalchemy.orm import Session
17+
from sqlalchemy.testing import eq_, is_instance_of
18+
from google.cloud.spanner_v1 import (
19+
FixedSizePool,
20+
ResultSet,
21+
BatchCreateSessionsRequest,
22+
ExecuteSqlRequest,
23+
CommitRequest,
24+
BeginTransactionRequest,
25+
)
26+
from test.mockserver_tests.mock_server_test_base import (
27+
MockServerTestBase,
28+
add_result,
29+
)
30+
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
31+
import google.cloud.spanner_v1.types.type as spanner_type
32+
import google.cloud.spanner_v1.types.result_set as result_set
33+
34+
35+
class TestAutoIncrement(MockServerTestBase):
36+
def test_create_table(self):
37+
from test.mockserver_tests.auto_increment_model import Base
38+
39+
add_result(
40+
"""SELECT true
41+
FROM INFORMATION_SCHEMA.TABLES
42+
WHERE TABLE_SCHEMA="" AND TABLE_NAME="singers"
43+
LIMIT 1
44+
""",
45+
ResultSet(),
46+
)
47+
engine = create_engine(
48+
"spanner:///projects/p/instances/i/databases/d",
49+
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
50+
)
51+
Base.metadata.create_all(engine)
52+
requests = self.database_admin_service.requests
53+
eq_(1, len(requests))
54+
is_instance_of(requests[0], UpdateDatabaseDdlRequest)
55+
eq_(1, len(requests[0].statements))
56+
eq_(
57+
"CREATE TABLE singers (\n"
58+
"\tid INT64 NOT NULL "
59+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
60+
"\tname STRING(MAX) NOT NULL\n"
61+
") PRIMARY KEY (id)",
62+
requests[0].statements[0],
63+
)
64+
65+
def test_insert_row(self):
66+
from test.mockserver_tests.auto_increment_model import Singer
67+
68+
self.add_insert_result(
69+
"INSERT INTO singers (name) VALUES (@a0) THEN RETURN singers.id"
70+
)
71+
engine = create_engine(
72+
"spanner:///projects/p/instances/i/databases/d",
73+
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
74+
)
75+
76+
with Session(engine) as session:
77+
singer = Singer(name="Test")
78+
session.add(singer)
79+
# Flush the session to send the insert statement to the database.
80+
session.flush()
81+
eq_(987654321, singer.id)
82+
session.commit()
83+
# Verify the requests that we got.
84+
requests = self.spanner_service.requests
85+
eq_(4, len(requests))
86+
is_instance_of(requests[0], BatchCreateSessionsRequest)
87+
is_instance_of(requests[1], BeginTransactionRequest)
88+
is_instance_of(requests[2], ExecuteSqlRequest)
89+
is_instance_of(requests[3], CommitRequest)
90+
91+
def add_insert_result(self, sql):
92+
result = result_set.ResultSet(
93+
dict(
94+
metadata=result_set.ResultSetMetadata(
95+
dict(
96+
row_type=spanner_type.StructType(
97+
dict(
98+
fields=[
99+
spanner_type.StructType.Field(
100+
dict(
101+
name="id",
102+
type=spanner_type.Type(
103+
dict(code=spanner_type.TypeCode.INT64)
104+
),
105+
)
106+
)
107+
]
108+
)
109+
)
110+
)
111+
),
112+
stats=result_set.ResultSetStats(
113+
dict(
114+
row_count_exact=1,
115+
)
116+
),
117+
)
118+
)
119+
result.rows.extend([("987654321",)])
120+
add_result(sql, result)

test/mockserver_tests/test_basics.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def test_create_table(self):
127127
eq_(1, len(requests[0].statements))
128128
eq_(
129129
"CREATE TABLE users (\n"
130-
"\tuser_id INT64 NOT NULL, \n"
130+
"\tuser_id INT64 NOT NULL "
131+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
131132
"\tuser_name STRING(16) NOT NULL\n"
132133
") PRIMARY KEY (user_id)",
133134
requests[0].statements[0],
@@ -163,7 +164,8 @@ def test_create_multiple_tables(self):
163164
for i in range(2):
164165
eq_(
165166
f"CREATE TABLE table{i} (\n"
166-
"\tid INT64 NOT NULL, \n"
167+
"\tid INT64 NOT NULL "
168+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
167169
"\tvalue STRING(16) NOT NULL"
168170
"\n) PRIMARY KEY (id)",
169171
requests[0].statements[i],

test/mockserver_tests/test_json.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def test_create_table(self):
5858
eq_(1, len(requests[0].statements))
5959
eq_(
6060
"CREATE TABLE venues (\n"
61-
"\tid INT64 NOT NULL, \n"
61+
"\tid INT64 NOT NULL "
62+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
6263
"\tname STRING(MAX) NOT NULL, \n"
6364
"\tdescription JSON\n"
6465
") PRIMARY KEY (id)",

test/mockserver_tests/test_quickstart.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,17 @@ def test_create_tables(self):
5353
eq_(2, len(requests[0].statements))
5454
eq_(
5555
"CREATE TABLE user_account (\n"
56-
"\tid INT64 NOT NULL, \n"
56+
"\tid INT64 NOT NULL "
57+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
5758
"\tname STRING(30) NOT NULL, \n"
5859
"\tfullname STRING(MAX)\n"
5960
") PRIMARY KEY (id)",
6061
requests[0].statements[0],
6162
)
6263
eq_(
6364
"CREATE TABLE address (\n"
64-
"\tid INT64 NOT NULL, \n"
65+
"\tid INT64 NOT NULL "
66+
"GENERATED BY DEFAULT AS IDENTITY (BIT_REVERSED_POSITIVE), \n"
6567
"\temail_address STRING(MAX) NOT NULL, \n"
6668
"\tuser_id INT64 NOT NULL, \n"
6769
"\tFOREIGN KEY(user_id) REFERENCES user_account (id)\n"

0 commit comments

Comments
 (0)