Skip to content

Commit 87d8548

Browse files
committed
define snowflake catalog
1 parent 0155405 commit 87d8548

File tree

6 files changed

+1201
-680
lines changed

6 files changed

+1201
-680
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ install-poetry:
1919
pip install poetry==1.8.2
2020

2121
install-dependencies:
22-
poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft
22+
poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft -E snowflake
2323

2424
install: | install-poetry install-dependencies
2525

poetry.lock

+729-619
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyiceberg/catalog/__init__.py

+11
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class CatalogType(Enum):
106106
GLUE = "glue"
107107
DYNAMODB = "dynamodb"
108108
SQL = "sql"
109+
SNOWFLAKE = "snowflake"
109110

110111

111112
def load_rest(name: str, conf: Properties) -> Catalog:
@@ -152,12 +153,22 @@ def load_sql(name: str, conf: Properties) -> Catalog:
152153
) from exc
153154

154155

156+
def load_snowflake(name: str, conf: Properties) -> Catalog:
157+
try:
158+
from pyiceberg.catalog.snowflake_catalog import SnowflakeCatalog
159+
160+
return SnowflakeCatalog(name, **conf)
161+
except ImportError as exc:
162+
raise NotInstalledError("Snowflake support not installed: pip install 'pyiceberg[snowflake]'") from exc
163+
164+
155165
AVAILABLE_CATALOGS: dict[CatalogType, Callable[[str, Properties], Catalog]] = {
156166
CatalogType.REST: load_rest,
157167
CatalogType.HIVE: load_hive,
158168
CatalogType.GLUE: load_glue,
159169
CatalogType.DYNAMODB: load_dynamodb,
160170
CatalogType.SQL: load_sql,
171+
CatalogType.SNOWFLAKE: load_snowflake,
161172
}
162173

163174

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import os
5+
from dataclasses import dataclass
6+
from typing import Iterator, List, Optional, Set, Union
7+
8+
import pyarrow as pa
9+
from boto3.session import Session
10+
from snowflake.connector import DictCursor, SnowflakeConnection
11+
12+
from pyiceberg.catalog import MetastoreCatalog, PropertiesUpdateSummary
13+
from pyiceberg.exceptions import NoSuchTableError, TableAlreadyExistsError
14+
from pyiceberg.io import S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN
15+
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
16+
from pyiceberg.schema import Schema
17+
from pyiceberg.table import CommitTableRequest, CommitTableResponse, StaticTable, Table, sorting
18+
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
19+
20+
21+
class SnowflakeCatalog(MetastoreCatalog):
22+
@dataclass(frozen=True, eq=True)
23+
class _SnowflakeIdentifier:
24+
database: str | None
25+
schema: str | None
26+
table: str | None
27+
28+
def __iter__(self) -> Iterator[str]:
29+
"""
30+
Iterate of the non-None parts of the identifier.
31+
32+
Returns:
33+
Iterator[str]: Iterator of the non-None parts of the identifier.
34+
"""
35+
yield from filter(None, [self.database, self.schema, self.table])
36+
37+
@classmethod
38+
def table_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
39+
parts = identifier.split(".")
40+
if len(parts) == 1:
41+
return cls(None, None, parts[0])
42+
elif len(parts) == 2:
43+
return cls(None, parts[0], parts[1])
44+
elif len(parts) == 3:
45+
return cls(parts[0], parts[1], parts[2])
46+
47+
raise ValueError(f"Invalid identifier: {identifier}")
48+
49+
@classmethod
50+
def schema_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
51+
parts = identifier.split(".")
52+
if len(parts) == 1:
53+
return cls(None, parts[0], None)
54+
elif len(parts) == 2:
55+
return cls(parts[0], parts[1], None)
56+
57+
raise ValueError(f"Invalid identifier: {identifier}")
58+
59+
@property
60+
def table_name(self) -> str:
61+
return ".".join(self)
62+
63+
@property
64+
def schema_name(self) -> str:
65+
return ".".join(self)
66+
67+
def __init__(self, name: str, **properties: str):
68+
super().__init__(name, **properties)
69+
70+
params = {
71+
"user": properties["user"],
72+
"account": properties["account"],
73+
}
74+
75+
if "authenticator" in properties:
76+
params["authenticator"] = properties["authenticator"]
77+
78+
if "password" in properties:
79+
params["password"] = properties["password"]
80+
81+
if "private_key" in properties:
82+
params["private_key"] = properties["private_key"]
83+
84+
self.connection = SnowflakeConnection(**params)
85+
86+
def load_table(self, identifier: Union[str, Identifier]) -> Table:
87+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
88+
identifier if isinstance(identifier, str) else ".".join(identifier)
89+
)
90+
91+
metadata_query = "SELECT SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s) AS METADATA"
92+
93+
with self.connection.cursor(DictCursor) as cursor:
94+
try:
95+
cursor.execute(metadata_query, (sf_identifier.table_name,))
96+
metadata = json.loads(cursor.fetchone()["METADATA"])["metadataLocation"]
97+
except Exception as e:
98+
raise NoSuchTableError(f"Table {sf_identifier.table_name} not found") from e
99+
100+
session = Session()
101+
credentials = session.get_credentials()
102+
current_credentials = credentials.get_frozen_credentials()
103+
104+
s3_props = {
105+
S3_ACCESS_KEY_ID: current_credentials.access_key,
106+
S3_SECRET_ACCESS_KEY: current_credentials.secret_key,
107+
S3_SESSION_TOKEN: current_credentials.token,
108+
S3_REGION: os.environ.get("AWS_REGION", "us-east-1"),
109+
}
110+
111+
tbl = StaticTable.from_metadata(metadata, properties=s3_props)
112+
tbl.identifier = tuple(identifier.split(".")) if isinstance(identifier, str) else identifier
113+
tbl.catalog = self
114+
115+
return tbl
116+
117+
def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
118+
query = "CREATE ICEBERG TABLE (%s) METADATA_FILE_PATH = (%s)"
119+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
120+
identifier if isinstance(identifier, str) else ".".join(identifier)
121+
)
122+
123+
with self.connection.cursor(DictCursor) as cursor:
124+
try:
125+
cursor.execute(query, (sf_identifier.table_name, metadata_location))
126+
except Exception as e:
127+
raise TableAlreadyExistsError(f"Table {sf_identifier.table_name} already exists") from e
128+
129+
return self.load_table(identifier)
130+
131+
def drop_table(self, identifier: Union[str, Identifier]) -> None:
132+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
133+
identifier if isinstance(identifier, str) else ".".join(identifier)
134+
)
135+
136+
query = "DROP TABLE IF EXISTS (%s)"
137+
138+
with self.connection.cursor(DictCursor) as cursor:
139+
cursor.execute(query, (sf_identifier.table_name,))
140+
141+
def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
142+
sf_from_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
143+
from_identifier if isinstance(from_identifier, str) else ".".join(from_identifier)
144+
)
145+
sf_to_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
146+
to_identifier if isinstance(to_identifier, str) else ".".join(to_identifier)
147+
)
148+
149+
query = "ALTER TABLE (%s) RENAME TO (%s)"
150+
151+
with self.connection.cursor(DictCursor) as cursor:
152+
cursor.execute(query, (sf_from_identifier.table_name, sf_to_identifier.table_name))
153+
154+
return self.load_table(to_identifier)
155+
156+
def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
157+
raise NotImplementedError
158+
159+
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
160+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
161+
namespace if isinstance(namespace, str) else ".".join(namespace)
162+
)
163+
164+
db_query = "CREATE DATABASE IF NOT EXISTS (%s)"
165+
schema_query = "CREATE SCHEMA IF NOT EXISTS (%s)"
166+
167+
with self.connection.cursor(DictCursor) as cursor:
168+
if sf_identifier.database:
169+
cursor.execute(db_query, (sf_identifier.database,))
170+
cursor.execute(schema_query, (sf_identifier.schema_name,))
171+
172+
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
173+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
174+
namespace if isinstance(namespace, str) else ".".join(namespace)
175+
)
176+
177+
sf_query = "DROP SCHEMA IF EXISTS (%s)"
178+
db_query = "DROP DATABASE IF EXISTS (%s)"
179+
180+
with self.connection.cursor(DictCursor) as cursor:
181+
if sf_identifier.database:
182+
cursor.execute(db_query, (sf_identifier.database,))
183+
cursor.execute(sf_query, (sf_identifier.schema_name,))
184+
185+
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
186+
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
187+
namespace if isinstance(namespace, str) else ".".join(namespace)
188+
)
189+
190+
schema_query = "SHOW ICEBERG TABLES IN SCHEMA (%s)"
191+
db_query = "SHOW ICEBERG TABLES IN DATABASE (%s)"
192+
193+
with self.connection.cursor(DictCursor) as cursor:
194+
if sf_identifier.database:
195+
cursor.execute(db_query, (sf_identifier.database,))
196+
else:
197+
cursor.execute(schema_query, (sf_identifier.schema,))
198+
199+
return [(row["database_name"], row["schema_name"], row["table_name"]) for row in cursor.fetchall()]
200+
201+
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
202+
raise NotImplementedError
203+
204+
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
205+
raise NotImplementedError
206+
207+
def update_namespace_properties(
208+
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
209+
) -> PropertiesUpdateSummary:
210+
raise NotImplementedError
211+
212+
def create_table(
213+
self,
214+
identifier: Union[str, Identifier],
215+
schema: Union[Schema, pa.Schema],
216+
location: Optional[str] = None,
217+
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
218+
sort_order: sorting.SortOrder = sorting.UNSORTED_SORT_ORDER,
219+
properties: Properties = EMPTY_DICT,
220+
) -> Table:
221+
raise NotImplementedError

0 commit comments

Comments
 (0)