|
| 1 | +# Copyright (c) 2021 The sqlalchemy-bigquery Authors |
| 2 | +# |
| 3 | +# Permission is hereby granted, free of charge, to any person obtaining a copy of |
| 4 | +# this software and associated documentation files (the "Software"), to deal in |
| 5 | +# the Software without restriction, including without limitation the rights to |
| 6 | +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
| 7 | +# the Software, and to permit persons to whom the Software is furnished to do so, |
| 8 | +# subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in all |
| 11 | +# copies or substantial portions of the Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 14 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
| 15 | +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
| 16 | +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
| 17 | +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
| 18 | +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
| 19 | + |
| 20 | +from typing import Mapping, Tuple |
| 21 | + |
| 22 | +import packaging.version |
| 23 | +import sqlalchemy.sql.default_comparator |
| 24 | +import sqlalchemy.sql.sqltypes |
| 25 | +import sqlalchemy.types |
| 26 | + |
| 27 | +from . import base |
| 28 | + |
| 29 | +sqlalchemy_1_4_or_more = packaging.version.parse( |
| 30 | + sqlalchemy.__version__ |
| 31 | +) >= packaging.version.parse("1.4") |
| 32 | + |
| 33 | +if sqlalchemy_1_4_or_more: |
| 34 | + import sqlalchemy.sql.coercions |
| 35 | + import sqlalchemy.sql.roles |
| 36 | + |
| 37 | + |
| 38 | +def _get_subtype_col_spec(type_): |
| 39 | + global _get_subtype_col_spec |
| 40 | + |
| 41 | + type_compiler = base.dialect.type_compiler(base.dialect()) |
| 42 | + _get_subtype_col_spec = type_compiler.process |
| 43 | + return _get_subtype_col_spec(type_) |
| 44 | + |
| 45 | + |
| 46 | +class STRUCT(sqlalchemy.sql.sqltypes.Indexable, sqlalchemy.types.UserDefinedType): |
| 47 | + """ |
| 48 | + A type for BigQuery STRUCT/RECORD data |
| 49 | +
|
| 50 | + See https://googleapis.dev/python/sqlalchemy-bigquery/latest/struct.html |
| 51 | + """ |
| 52 | + |
| 53 | + # See https://docs.sqlalchemy.org/en/14/core/custom_types.html#creating-new-types |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + *fields: Tuple[str, sqlalchemy.types.TypeEngine], |
| 58 | + **kwfields: Mapping[str, sqlalchemy.types.TypeEngine], |
| 59 | + ): |
| 60 | + # Note that because: |
| 61 | + # https://docs.python.org/3/whatsnew/3.6.html#pep-468-preserving-keyword-argument-order |
| 62 | + # We know that `kwfields` preserves order. |
| 63 | + self._STRUCT_fields = tuple( |
| 64 | + ( |
| 65 | + name, |
| 66 | + type_ if isinstance(type_, sqlalchemy.types.TypeEngine) else type_(), |
| 67 | + ) |
| 68 | + for (name, type_) in (fields + tuple(kwfields.items())) |
| 69 | + ) |
| 70 | + |
| 71 | + self._STRUCT_byname = { |
| 72 | + name.lower(): type_ for (name, type_) in self._STRUCT_fields |
| 73 | + } |
| 74 | + |
| 75 | + def __repr__(self): |
| 76 | + fields = ", ".join( |
| 77 | + f"{name}={repr(type_)}" for name, type_ in self._STRUCT_fields |
| 78 | + ) |
| 79 | + return f"STRUCT({fields})" |
| 80 | + |
| 81 | + def get_col_spec(self, **kw): |
| 82 | + fields = ", ".join( |
| 83 | + f"{name} {_get_subtype_col_spec(type_)}" |
| 84 | + for name, type_ in self._STRUCT_fields |
| 85 | + ) |
| 86 | + return f"STRUCT<{fields}>" |
| 87 | + |
| 88 | + def bind_processor(self, dialect): |
| 89 | + return dict |
| 90 | + |
| 91 | + class Comparator(sqlalchemy.sql.sqltypes.Indexable.Comparator): |
| 92 | + def _setup_getitem(self, name): |
| 93 | + if not isinstance(name, str): |
| 94 | + raise TypeError( |
| 95 | + f"STRUCT fields can only be accessed with strings field names," |
| 96 | + f" not {repr(name)}." |
| 97 | + ) |
| 98 | + subtype = self.expr.type._STRUCT_byname.get(name.lower()) |
| 99 | + if subtype is None: |
| 100 | + raise KeyError(name) |
| 101 | + operator = struct_getitem_op |
| 102 | + index = _field_index(self, name, operator) |
| 103 | + return operator, index, subtype |
| 104 | + |
| 105 | + def __getattr__(self, name): |
| 106 | + if name.lower() in self.expr.type._STRUCT_byname: |
| 107 | + return self[name] |
| 108 | + |
| 109 | + comparator_factory = Comparator |
| 110 | + |
| 111 | + |
| 112 | +# In the implementations of _field_index below, we're stealing from |
| 113 | +# the JSON type implementation, but the code to steal changed in |
| 114 | +# 1.4. :/ |
| 115 | + |
| 116 | +if sqlalchemy_1_4_or_more: |
| 117 | + |
| 118 | + def _field_index(self, name, operator): |
| 119 | + return sqlalchemy.sql.coercions.expect( |
| 120 | + sqlalchemy.sql.roles.BinaryElementRole, |
| 121 | + name, |
| 122 | + expr=self.expr, |
| 123 | + operator=operator, |
| 124 | + bindparam_type=sqlalchemy.types.String(), |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +else: |
| 129 | + |
| 130 | + def _field_index(self, name, operator): |
| 131 | + return sqlalchemy.sql.default_comparator._check_literal( |
| 132 | + self.expr, operator, name, bindparam_type=sqlalchemy.types.String(), |
| 133 | + ) |
| 134 | + |
| 135 | + |
| 136 | +def struct_getitem_op(a, b): |
| 137 | + raise NotImplementedError() |
| 138 | + |
| 139 | + |
| 140 | +sqlalchemy.sql.default_comparator.operator_lookup[ |
| 141 | + struct_getitem_op.__name__ |
| 142 | +] = sqlalchemy.sql.default_comparator.operator_lookup["json_getitem_op"] |
| 143 | + |
| 144 | + |
| 145 | +class SQLCompiler: |
| 146 | + def visit_struct_getitem_op_binary(self, binary, operator_, **kw): |
| 147 | + left = self.process(binary.left, **kw) |
| 148 | + return f"{left}.{binary.right.value}" |
0 commit comments