Skip to content

Commit b1b6182

Browse files
committed
Add enum for db schemes
1 parent fcbddcf commit b1b6182

File tree

11 files changed

+88
-51
lines changed

11 files changed

+88
-51
lines changed

mautrix/client/state_store/asyncpg/store.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
RoomID,
1717
UserID,
1818
)
19-
from mautrix.util.async_db import Database
19+
from mautrix.util.async_db import Database, Scheme
2020

2121
from ..abstract import StateStore
2222
from .upgrade import upgrade_table
@@ -80,7 +80,7 @@ async def get_members(
8080
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
8181
) -> list[UserID]:
8282
membership_values = [membership.value for membership in memberships]
83-
if self.db.scheme == "postgres":
83+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
8484
q = "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND membership=ANY($2)"
8585
res = await self.db.fetch(q, room_id, membership_values)
8686
else:
@@ -98,7 +98,7 @@ async def get_member_profiles(
9898
memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
9999
) -> dict[UserID, Member]:
100100
membership_values = [membership.value for membership in memberships]
101-
if self.db.scheme == "postgres":
101+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
102102
q = (
103103
"SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
104104
"WHERE room_id=$1 AND membership=ANY($2)"
@@ -123,7 +123,7 @@ async def get_members_filtered(
123123
) -> list[UserID]:
124124
not_like = f"{not_prefix}%{not_suffix}"
125125
membership_values = [membership.value for membership in memberships]
126-
if self.db.scheme == "postgres":
126+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
127127
q = (
128128
"SELECT user_id FROM mx_user_profile "
129129
"WHERE room_id=$1 AND membership=ANY($2)"
@@ -155,15 +155,15 @@ async def set_members(
155155
del_q = "DELETE FROM mx_user_profile WHERE room_id=$1"
156156
if only_membership is None:
157157
await conn.execute(del_q, room_id)
158-
elif self.db.scheme == "postgres":
158+
elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
159159
del_q = f"{del_q} AND (membership=$2 OR user_id = ANY($3))"
160160
await conn.execute(del_q, room_id, only_membership.value, list(members.keys()))
161161
else:
162162
member_placeholders = ("?," * len(members)).rstrip(",")
163163
del_q = f"{del_q} AND (membership=? OR user_id IN ({member_placeholders}))"
164164
await conn.execute(del_q, room_id, only_membership.value, *members.keys())
165165

166-
if self.db.scheme == "postgres":
166+
if self.db.scheme == Scheme.POSTGRES:
167167
await conn.copy_records_to_table(
168168
"mx_user_profile", records=records, columns=columns
169169
)

mautrix/client/state_store/asyncpg/upgrade.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
import logging
77

8-
from asyncpg import Connection
9-
10-
from mautrix.util.async_db.upgrade import UpgradeTable
8+
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
119

1210
upgrade_table = UpgradeTable(
1311
version_table_name="mx_version",
@@ -17,7 +15,7 @@
1715

1816

1917
@upgrade_table.register(description="Latest revision", upgrades_to=2)
20-
async def upgrade_blank_to_v2(conn: Connection, scheme: str) -> None:
18+
async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None:
2119
await conn.execute(
2220
"""CREATE TABLE mx_room_state (
2321
room_id TEXT PRIMARY KEY,
@@ -27,7 +25,7 @@ async def upgrade_blank_to_v2(conn: Connection, scheme: str) -> None:
2725
power_levels TEXT
2826
)"""
2927
)
30-
if scheme != "sqlite":
28+
if scheme != Scheme.SQLITE:
3129
await conn.execute(
3230
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
3331
)
@@ -44,8 +42,8 @@ async def upgrade_blank_to_v2(conn: Connection, scheme: str) -> None:
4442

4543

4644
@upgrade_table.register(description="Stop using size-limited string fields")
47-
async def upgrade_v2(conn: Connection, scheme: str) -> None:
48-
if scheme == "sqlite":
45+
async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
46+
if scheme == Scheme.SQLITE:
4947
# SQLite doesn't care about types
5048
return
5149
await conn.execute("ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT")

mautrix/crypto/store/asyncpg/store.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import cast
98
from collections import defaultdict
109
from datetime import timedelta
1110

1211
from mautrix.client.state_store import SyncStore
1312
from mautrix.client.state_store.asyncpg import PgStateStore
1413
from mautrix.types import DeviceID, EventID, IdentityKey, RoomID, SessionID, SyncToken, UserID
15-
from mautrix.util.async_db import Database
14+
from mautrix.util.async_db import Database, Scheme
1615
from mautrix.util.logging import TraceLogger
1716

1817
from ... import (
@@ -339,7 +338,7 @@ async def remove_outbound_group_session(self, room_id: RoomID) -> None:
339338
)
340339

341340
async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
342-
if self.db.scheme == "postgres":
341+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
343342
await self.db.execute(
344343
"DELETE FROM crypto_megolm_outbound_session "
345344
"WHERE account_id=$1 AND room_id=ANY($2)",
@@ -374,7 +373,7 @@ async def validate_message_index(
374373
index: int,
375374
timestamp: int,
376375
) -> bool:
377-
if self.db.scheme == "postgres":
376+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
378377
row = await self.db.fetchrow(
379378
self._validate_message_index_query,
380379
sender_key,
@@ -499,7 +498,7 @@ async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdent
499498
user_id,
500499
)
501500
await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id)
502-
if self.db.scheme == "postgres":
501+
if self.db.scheme == Scheme.POSTGRES:
503502
await conn.copy_records_to_table("crypto_device", records=data, columns=columns)
504503
else:
505504
await conn.executemany(
@@ -510,7 +509,7 @@ async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdent
510509
)
511510

512511
async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
513-
if self.db.scheme == "postgres":
512+
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
514513
rows = await self.db.fetch(
515514
"SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", users
516515
)

mautrix/crypto/store/asyncpg/upgrade.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging
99

10-
from mautrix.util.async_db import Connection, UpgradeTable
10+
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
1111

1212
upgrade_table = UpgradeTable(
1313
version_table_name="crypto_version",
@@ -96,8 +96,8 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
9696

9797

9898
@upgrade_table.register(description="Add account_id primary key column")
99-
async def upgrade_v2(conn: Connection, scheme: str) -> None:
100-
if scheme == "sqlite":
99+
async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
100+
if scheme == Scheme.SQLITE:
101101
await conn.execute("DROP TABLE crypto_account")
102102
await conn.execute("DROP TABLE crypto_olm_session")
103103
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
@@ -170,8 +170,8 @@ async def add_account_id_column(table: str, pkey_columns: list[str]) -> None:
170170

171171

172172
@upgrade_table.register(description="Stop using size-limited string fields")
173-
async def upgrade_v3(conn: Connection, scheme: str) -> None:
174-
if scheme == "sqlite":
173+
async def upgrade_v3(conn: Connection, scheme: Scheme) -> None:
174+
if scheme == Scheme.SQLITE:
175175
return
176176
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT")
177177
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT")
@@ -192,11 +192,11 @@ async def upgrade_v3(conn: Connection, scheme: str) -> None:
192192

193193

194194
@upgrade_table.register(description="Split last_used into last_encrypted and last_decrypted")
195-
async def upgrade_v4(conn: Connection, scheme: str) -> None:
195+
async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
196196
await conn.execute("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted")
197197
await conn.execute("ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp")
198198
await conn.execute("UPDATE crypto_olm_session SET last_encrypted=last_decrypted")
199-
if scheme == "postgres":
199+
if scheme == Scheme.POSTGRES:
200200
# This is too hard to do on sqlite, so let's just do it on postgres
201201
await conn.execute(
202202
"ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL"

mautrix/util/async_db/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .connection import LoggingConnection as Connection
44
from .database import Database
5+
from .scheme import Scheme
56
from .upgrade import UpgradeTable, register_upgrade
67

78
try:

mautrix/util/async_db/aiosqlite.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77

88
from typing import Any
99
from contextlib import asynccontextmanager
10-
from urllib.parse import urlparse
1110
import asyncio
1211
import logging
1312
import re
1413
import sqlite3
1514

15+
from yarl import URL
1616
import aiosqlite
1717

1818
from .connection import LoggingConnection
1919
from .database import Database
20+
from .scheme import Scheme
2021
from .upgrade import UpgradeTable
2122

2223
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")
@@ -73,20 +74,20 @@ async def fetchval(
7374

7475

7576
class SQLiteDatabase(Database):
76-
scheme = "sqlite"
77+
scheme = Scheme.SQLITE
7778
_pool: asyncio.Queue[TxnConnection]
7879
_stopped: bool
7980
_conns: int
8081

8182
def __init__(
8283
self,
83-
url: str,
84+
url: URL,
8485
upgrade_table: UpgradeTable,
8586
db_args: dict[str, Any] | None = None,
8687
log: logging.Logger | None = None,
8788
) -> None:
8889
super().__init__(url, db_args=db_args, upgrade_table=upgrade_table, log=log)
89-
self._path = urlparse(url).path
90+
self._path = url.path
9091
if self._path.startswith("/"):
9192
self._path = self._path[1:]
9293
self._pool = asyncio.Queue(self._db_args.pop("min_size", 1))

mautrix/util/async_db/asyncpg.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,31 @@
1010
import asyncio
1111
import logging
1212

13+
from yarl import URL
1314
import asyncpg
1415

1516
from .connection import LoggingConnection
1617
from .database import Database
18+
from .scheme import Scheme
1719
from .upgrade import UpgradeTable
1820

1921

2022
class PostgresDatabase(Database):
21-
scheme = "postgres"
23+
scheme = Scheme.POSTGRES
2224
_pool: asyncpg.pool.Pool | None
2325
_pool_override: bool
2426

2527
def __init__(
2628
self,
27-
url: str,
29+
url: URL,
2830
upgrade_table: UpgradeTable,
2931
db_args: dict[str, Any] = None,
3032
log: logging.Logger | None = None,
3133
) -> None:
34+
if url.scheme in ("cockroach", "cockroachdb"):
35+
self.scheme = Scheme.COCKROACH
36+
# Send postgres scheme to asyncpg
37+
url = url.with_scheme("postgres")
3238
super().__init__(url, db_args=db_args, upgrade_table=upgrade_table, log=log)
3339
self._pool = None
3440
self._pool_override = False
@@ -41,7 +47,7 @@ async def start(self) -> None:
4147
if not self._pool_override:
4248
self._db_args["loop"] = asyncio.get_running_loop()
4349
self.log.debug(f"Connecting to {self.url}")
44-
self._pool = await asyncpg.create_pool(self.url, **self._db_args)
50+
self._pool = await asyncpg.create_pool(str(self.url), **self._db_args)
4551
await super().start()
4652

4753
@property
@@ -62,3 +68,5 @@ async def acquire(self) -> LoggingConnection:
6268

6369
Database.schemes["postgres"] = PostgresDatabase
6470
Database.schemes["postgresql"] = PostgresDatabase
71+
Database.schemes["cockroach"] = PostgresDatabase
72+
Database.schemes["cockroachdb"] = PostgresDatabase

mautrix/util/async_db/connection.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from mautrix import __optional_imports__
1515
from mautrix.util.logging import SILLY, TraceLogger
1616

17+
from .scheme import Scheme
18+
1719
if __optional_imports__:
1820
from sqlite3 import Row
1921

@@ -41,12 +43,15 @@ async def wrapper(self: LoggingConnection, arg: str, *args: Any, **kwargs: str)
4143

4244

4345
class LoggingConnection:
44-
scheme: str
46+
scheme: Scheme
4547
wrapped: aiosqlite.TxnConnection | asyncpg.Connection
4648
log: TraceLogger
4749

4850
def __init__(
49-
self, scheme: str, wrapped: aiosqlite.TxnConnection | asyncpg.Connection, log: TraceLogger
51+
self,
52+
scheme: Scheme,
53+
wrapped: aiosqlite.TxnConnection | asyncpg.Connection,
54+
log: TraceLogger,
5055
) -> None:
5156
self.scheme = scheme
5257
self.wrapped = wrapped
@@ -97,8 +102,8 @@ async def copy_records_to_table(
97102
schema_name: str | None = None,
98103
timeout: float | None = None,
99104
) -> None:
100-
if self.scheme != "postgres":
101-
raise RuntimeError("copy_records_to_table is not supported on SQLite")
105+
if self.scheme != Scheme.POSTGRES:
106+
raise RuntimeError("copy_records_to_table is only supported on Postgres")
102107
return await self.wrapped.copy_records_to_table(
103108
table_name, records=records, columns=columns, schema_name=schema_name, timeout=timeout
104109
)

0 commit comments

Comments
 (0)