
Commit 4f3545b

fix: Add write support for roundtrip test
1 parent 5980d1e commit 4f3545b

11 files changed (+165 -31 lines)

cloudquery/sdk/internal/memdb/memdb.py (+20 -8)

@@ -11,18 +11,30 @@
 class MemDB(plugin.Plugin):
     def __init__(self) -> None:
         super().__init__(NAME, VERSION)
-        self._tables: List[schema.Table] = [
-            schema.Table("test_table", [schema.Column("test_column", pa.int64())])
-        ]
-        self._memory_db: Dict[str, pa.record] = {
-            "test_table": pa.record_batch([pa.array([1, 2, 3])], names=["test_column"])
-        }
+        self._db: Dict[str, pa.RecordBatch] = {}
+        self._tables: Dict[str, schema.Table] = {}
 
     def get_tables(self, options: plugin.TableOptions = None) -> List[plugin.Table]:
-        return self._tables
+        tables = list(self._tables.values())
+        return schema.filter_dfs(tables, options.tables, options.skip_tables)
 
     def sync(
         self, options: plugin.SyncOptions
     ) -> Generator[message.SyncMessage, None, None]:
-        for table, record in self._memory_db.items():
+        for table, record in self._db.items():
             yield message.SyncInsertMessage(record)
+
+    def write(self, msg_iterator: Generator[message.WriteMessage, None, None]) -> None:
+        for msg in msg_iterator:
+            if type(msg) == message.WriteMigrateTableMessage:
+                if msg.table.name not in self._db:
+                    self._db[msg.table.name] = msg.table
+                self._tables[msg.table.name] = msg.table
+            elif type(msg) == message.WriteInsertMessage:
+                table = schema.Table.from_arrow_schema(msg.record.schema)
+                self._db[table.name] = msg.record
+            else:
+                raise NotImplementedError(f"Unknown message type {type(msg)}")
+
+    def close(self) -> None:
+        self._db = {}
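
Taken together, these changes make an in-process write-then-sync roundtrip possible. A minimal sketch, using only names that appear in this commit (the init(None) call mirrors the repo's own tests):

import pyarrow as pa

from cloudquery.sdk import message
from cloudquery.sdk.internal.memdb import MemDB
from cloudquery.sdk.plugin.plugin import SyncOptions
from cloudquery.sdk.schema import Column, Table

p = MemDB()
p.init(None)

table = Table("test", [Column("id", pa.int64())])
record = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3])], schema=table.to_arrow_schema()
)

# Migrate registers the table; insert stores the record batch under its name.
p.write(
    iter(
        [
            message.WriteMigrateTableMessage(table=table),
            message.WriteInsertMessage(record=record),
        ]
    )
)

# Sync now yields exactly the batch that was written.
assert len(list(p.sync(SyncOptions(tables=["*"])))) == 1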

cloudquery/sdk/internal/servers/plugin_v3/plugin.py (+38 -4)

@@ -1,10 +1,18 @@
 import pyarrow as pa
 import structlog
 
+from typing import Generator
 from cloudquery.plugin_v3 import plugin_pb2, plugin_pb2_grpc, arrow
-from cloudquery.sdk.message import SyncInsertMessage, SyncMigrateTableMessage
+from cloudquery.sdk.message import (
+    SyncInsertMessage,
+    SyncMigrateTableMessage,
+    WriteInsertMessage,
+    WriteMigrateTableMessage,
+    WriteMessage,
+    WriteDeleteStale,
+)
 from cloudquery.sdk.plugin.plugin import Plugin, SyncOptions, TableOptions
-from cloudquery.sdk.schema import tables_to_arrow_schemas
+from cloudquery.sdk.schema import tables_to_arrow_schemas, Table
 
 
 class PluginServicer(plugin_pb2_grpc.PluginServicer):
@@ -64,8 +72,34 @@ def Sync(self, request, context):
     def Read(self, request, context):
         raise NotImplementedError()
 
-    def Write(self, request_iterator, context):
-        raise NotImplementedError()
+    def Write(
+        self, request_iterator: Generator[plugin_pb2.Write.Request, None, None], context
+    ):
+        def msg_iterator() -> Generator[WriteMessage, None, None]:
+            for msg in request_iterator:
+                field = msg.WhichOneof("message")
+                if field == "migrate_table":
+                    sc = arrow.new_schema_from_bytes(msg.migrate_table.table)
+                    table = Table.from_arrow_schema(sc)
+                    yield WriteMigrateTableMessage(table=table)
+                elif field == "insert":
+                    yield WriteInsertMessage(
+                        record=arrow.new_record_from_bytes(msg.insert.record)
+                    )
+                elif field == "delete":
+                    yield WriteDeleteStale(
+                        table_name=msg.delete.table_name,
+                        source_name=msg.delete.source_name,
+                        sync_time=msg.delete.sync_time.ToDatetime(),
+                    )
+                elif field is None:
+                    continue
+                else:
+                    raise NotImplementedError(f"unknown write message type {field}")
+
+        self._plugin.write(msg_iterator())
+        return plugin_pb2.Write.Response()
 
     def Close(self, request, context):
+        self._plugin.close()
         return plugin_pb2.Close.Response()
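
The Write handler dispatches on the protobuf oneof rather than on Python types: WhichOneof("message") returns the name of whichever field is set, or None for an empty frame, which the loop skips. A small illustration against the same generated classes, assuming the proto shapes this commit's test exercises:

from cloudquery.plugin_v3 import plugin_pb2

req = plugin_pb2.Write.Request(
    migrate_table=plugin_pb2.Write.MessageMigrateTable(table=b"")
)
assert req.WhichOneof("message") == "migrate_table"

# A request with no oneof field set returns None, not an empty string.
assert plugin_pb2.Write.Request().WhichOneof("message") is None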

cloudquery/sdk/message/__init__.py (+6)

@@ -1 +1,7 @@
 from .sync import SyncMessage, SyncInsertMessage, SyncMigrateTableMessage
+from .write import (
+    WriteMessage,
+    WriteInsertMessage,
+    WriteMigrateTableMessage,
+    WriteDeleteStale,
+)

cloudquery/sdk/message/write.py (+11 -3)

@@ -1,15 +1,23 @@
 import pyarrow as pa
+from cloudquery.sdk.schema import Table
 
 
 class WriteMessage:
     pass
 
 
-class InsertMessage(WriteMessage):
+class WriteInsertMessage(WriteMessage):
     def __init__(self, record: pa.RecordBatch):
         self.record = record
 
 
-class MigrateMessage(WriteMessage):
-    def __init__(self, table: pa.Schema):
+class WriteMigrateTableMessage(WriteMessage):
+    def __init__(self, table: Table):
         self.table = table
+
+
+class WriteDeleteStale(WriteMessage):
+    def __init__(self, table_name: str, source_name: str, sync_time):
+        self.table_name = table_name
+        self.source_name = source_name
+        self.sync_time = sync_time
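
sync_time is left unannotated, but in practice it arrives as a datetime, since the servicer converts the protobuf Timestamp via ToDatetime(). A minimal construction; the source name here is hypothetical:

from datetime import datetime, timezone

from cloudquery.sdk.message import WriteDeleteStale

msg = WriteDeleteStale(
    table_name="test",
    source_name="example-source",  # hypothetical source name
    sync_time=datetime.now(timezone.utc),
)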

cloudquery/sdk/plugin/plugin.py (+4 -1)

@@ -50,5 +50,8 @@ def get_tables(self, options: TableOptions) -> List[Table]:
     def sync(self, options: SyncOptions) -> Generator[message.SyncMessage, None, None]:
         raise NotImplementedError()
 
+    def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
+        raise NotImplementedError()
+
     def close(self) -> None:
-        pass
+        raise NotImplementedError()
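
With write and close now part of the abstract contract, a destination plugin must override both. A minimal sketch of such a subclass; the plugin name and version strings are hypothetical:

from typing import Generator

from cloudquery.sdk import message
from cloudquery.sdk.plugin.plugin import Plugin


class NullDestination(Plugin):
    # Hypothetical destination that drains and discards every write message.
    def __init__(self) -> None:
        super().__init__("null", "0.0.1")

    def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
        for msg in writer:
            pass  # a real plugin would migrate tables and persist records here

    def close(self) -> None:
        pass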

cloudquery/sdk/scheduler/scheduler.py (+12 -1)

@@ -103,7 +103,18 @@ def resolve_table(
         )
         total_resources = 0
         for item in resolver.resolve(client, parent_item):
-            resource = self.resolve_resource(resolver, client, parent_item, item)
+            try:
+                resource = self.resolve_resource(
+                    resolver, client, parent_item, item
+                )
+            except Exception as e:
+                self._logger.error(
+                    "failed to resolve resource",
+                    table=resolver.table.name,
+                    depth=depth,
+                    exception=e,
+                )
+                continue
             res.put(SyncInsertMessage(resource.to_arrow_record()))
             for child_resolvers in resolver.child_resolvers:
                 self._pools[depth + 1].submit(
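
The try/except turns a single bad item into a logged skip instead of a failed table sync. The error call uses structlog's keyword-argument style, which attaches structured fields to the event; the standalone pattern looks like this (resolve here is a stand-in, not the SDK's resolve_resource):

import structlog

logger = structlog.get_logger()


def resolve(item):
    return item + 1  # stand-in resolver; raises TypeError on the string below


for item in [1, 2, "boom", 4]:
    try:
        resource = resolve(item)
    except Exception as e:
        logger.error("failed to resolve resource", item=item, exception=e)
        continue  # skip this item, keep processing the rest
    # ... enqueue `resource` for downstream processing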

cloudquery/sdk/schema/__init__.py (+1 -1)

@@ -1,5 +1,5 @@
 from .column import Column
-from .table import Table, tables_to_arrow_schemas
+from .table import Table, tables_to_arrow_schemas, filter_dfs
 from .resource import Resource
 
 # from .table_resolver import TableReso

cloudquery/sdk/schema/arrow.py (+8 -8)

@@ -1,9 +1,9 @@
-METADATA_UNIQUE = "cq:extension:unique"
-METADATA_PRIMARY_KEY = "cq:extension:primary_key"
-METADATA_CONSTRAINT_NAME = "cq:extension:constraint_name"
-METADATA_INCREMENTAL = "cq:extension:incremental"
+METADATA_UNIQUE = b"cq:extension:unique"
+METADATA_PRIMARY_KEY = b"cq:extension:primary_key"
+METADATA_CONSTRAINT_NAME = b"cq:extension:constraint_name"
+METADATA_INCREMENTAL = b"cq:extension:incremental"
 
-METADATA_TRUE = "true"
-METADATA_FALSE = "false"
-METADATA_TABLE_NAME = "cq:table_name"
-METADATA_TABLE_DESCRIPTION = "cq:table_description"
+METADATA_TRUE = b"true"
+METADATA_FALSE = b"false"
+METADATA_TABLE_NAME = b"cq:table_name"
+METADATA_TABLE_DESCRIPTION = b"cq:table_description"
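
These constants switch to bytes because pyarrow normalizes schema metadata keys and values to bytes: a str key that goes in comes back out as bytes, so lookups like schema.metadata[METADATA_TABLE_NAME] in from_arrow_schema below must use bytes keys. A quick check with nothing but pyarrow:

import pyarrow as pa

schema = pa.schema([pa.field("id", pa.int64())], metadata={"cq:table_name": "test"})

assert schema.metadata == {b"cq:table_name": b"test"}  # stored as bytes
assert schema.metadata.get("cq:table_name") is None    # str key misses
assert schema.metadata[b"cq:table_name"] == b"test"    # bytes key hits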

cloudquery/sdk/schema/table.py (+31 -1)

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from typing import List, Generator, Any
-
+import fnmatch
 import pyarrow as pa
 
 from cloudquery.sdk.schema import arrow
@@ -53,6 +53,19 @@ def primary_keys(self):
     def incremental_keys(self):
         return [column.name for column in self.columns if column.incremental_key]
 
+    @classmethod
+    def from_arrow_schema(cls, schema: pa.Schema) -> Table:
+        columns = []
+        for field in schema:
+            columns.append(Column.from_arrow_field(field))
+        return cls(
+            name=schema.metadata[arrow.METADATA_TABLE_NAME].decode("utf-8"),
+            columns=columns,
+            description=schema.metadata.get(arrow.METADATA_TABLE_DESCRIPTION).decode(
+                "utf-8"
+            ),
+        )
+
     def to_arrow_schema(self):
         fields = []
         md = {
@@ -74,3 +87,20 @@ def tables_to_arrow_schemas(tables: List[Table]):
     for table in tables:
         schemas.append(table.to_arrow_schema())
     return schemas
+
+
+def filter_dfs(
+    tables: List[Table], include_tables: List[str], skip_tables: List[str]
+) -> List[Table]:
+    filtered: List[Table] = []
+    for table in tables:
+        matched = False
+        for include_table in include_tables:
+            if fnmatch.fnmatch(table.name, include_table):
+                matched = True
+        for skip_table in skip_tables:
+            if fnmatch.fnmatch(table.name, skip_table):
+                matched = False
+        if matched:
+            filtered.append(table)
+    return filtered
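
filter_dfs matches table names against fnmatch-style globs; because the skip loop runs after the include loop, a skip pattern always wins. A short example with hypothetical table names:

from cloudquery.sdk.schema import Table, filter_dfs

tables = [
    Table(name, []) for name in ("aws_ec2_instances", "aws_s3_buckets", "gcp_buckets")
]

selected = filter_dfs(tables, include_tables=["aws_*"], skip_tables=["*_s3_*"])
assert [t.name for t in selected] == ["aws_ec2_instances"]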

tests/internal/memdb/memdb.py (+2 -2)

@@ -6,6 +6,6 @@ def test_memdb():
     p = memdb.MemDB()
     p.init(None)
     msgs = []
-    for msg in p.sync(SyncOptions()):
+    for msg in p.sync(SyncOptions(tables=["*"])):
         msgs.append(msg)
-    assert len(msgs) == 1
+    assert len(msgs) == 0
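
The expected count drops from 1 to 0 because MemDB no longer seeds itself with a hard-coded test_table; it starts empty, and records only appear after a Write, as the serve test below exercises.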

tests/serve/plugin.py (+32 -2)

@@ -1,12 +1,22 @@
 import random
 import grpc
 import time
+import pyarrow as pa
 from concurrent import futures
+from cloudquery.sdk.schema import Table, Column
 from cloudquery.sdk import serve
 from cloudquery.sdk import message
 from cloudquery.plugin_v3 import plugin_pb2_grpc, plugin_pb2, arrow
 from cloudquery.sdk.internal.memdb import MemDB
 
+test_table = Table(
+    "test",
+    [
+        Column("id", pa.int64()),
+        Column("name", pa.string()),
+    ],
+)
+
 
 def test_plugin_serve():
     p = MemDB()
@@ -27,11 +37,31 @@ def test_plugin_serve():
     response = stub.Init(plugin_pb2.Init.Request(spec=b""))
     assert response is not None
 
-    response = stub.GetTables(plugin_pb2.GetTables.Request())
+    def writer_iterator():
+        buf = arrow.schema_to_bytes(test_table.to_arrow_schema())
+        yield plugin_pb2.Write.Request(
+            migrate_table=plugin_pb2.Write.MessageMigrateTable(table=buf)
+        )
+        record = pa.RecordBatch.from_arrays(
+            [
+                pa.array([1, 2, 3]),
+                pa.array(["a", "b", "c"]),
+            ],
+            schema=test_table.to_arrow_schema(),
+        )
+        yield plugin_pb2.Write.Request(
+            insert=plugin_pb2.Write.MessageInsert(
+                record=arrow.record_to_bytes(record)
+            )
+        )
+
+    stub.Write(writer_iterator())
+
+    response = stub.GetTables(plugin_pb2.GetTables.Request(tables=["*"]))
     schemas = arrow.new_schemas_from_bytes(response.tables)
     assert len(schemas) == 1
 
-    response = stub.Sync(plugin_pb2.Sync.Request())
+    response = stub.Sync(plugin_pb2.Sync.Request(tables=["*"]))
     total_records = 0
     for msg in response:
         if msg.insert is not None:
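
Write is a client-streaming RPC, so the test hands the generator directly to stub.Write and grpc pulls requests from it as the server consumes the stream; the GetTables and Sync assertions then confirm that the written table and its records round-trip back out of the plugin.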
