Skip to content

Native support for additional Type Converters #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 139 additions & 65 deletions graphene_sqlalchemy/converter.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
@@ -144,9 +144,9 @@ def sort_enum_for_object_type(
column = orm_field.columns[0]
if only_indexed and not (column.primary_key or column.index):
continue
asc_name = get_name(column.name, True)
asc_name = get_name(column.key, True)
asc_value = EnumValue(asc_name, column.asc())
desc_name = get_name(column.name, False)
desc_name = get_name(column.key, False)
desc_value = EnumValue(desc_name, column.desc())
if column.primary_key:
default.append(asc_value)
38 changes: 29 additions & 9 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import defaultdict
from typing import List, Type

from sqlalchemy.types import Enum as SQLAlchemyEnumType

import graphene
from graphene import Enum


@@ -13,12 +15,13 @@ def __init__(self):
self._registry_composites = {}
self._registry_enums = {}
self._registry_sort_enums = {}
self._registry_unions = {}

def register(self, obj_type):
from .types import SQLAlchemyObjectType

from .types import SQLAlchemyObjectType
if not isinstance(obj_type, type) or not issubclass(
obj_type, SQLAlchemyObjectType
obj_type, SQLAlchemyObjectType
):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
@@ -37,7 +40,7 @@ def register_orm_field(self, obj_type, field_name, orm_field):
from .types import SQLAlchemyObjectType

if not isinstance(obj_type, type) or not issubclass(
obj_type, SQLAlchemyObjectType
obj_type, SQLAlchemyObjectType
):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
@@ -55,7 +58,7 @@ def register_composite_converter(self, composite, converter):
def get_converter_for_composite(self, composite):
return self._registry_composites.get(composite)

def register_enum(self, sa_enum, graphene_enum):
def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum):
if not isinstance(sa_enum, SQLAlchemyEnumType):
raise TypeError(
"Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum)
@@ -67,14 +70,14 @@ def register_enum(self, sa_enum, graphene_enum):

self._registry_enums[sa_enum] = graphene_enum

def get_graphene_enum_for_sa_enum(self, sa_enum):
def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType):
return self._registry_enums.get(sa_enum)

def register_sort_enum(self, obj_type, sort_enum):
from .types import SQLAlchemyObjectType
def register_sort_enum(self, obj_type, sort_enum: Enum):

from .types import SQLAlchemyObjectType
if not isinstance(obj_type, type) or not issubclass(
obj_type, SQLAlchemyObjectType
obj_type, SQLAlchemyObjectType
):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
@@ -83,9 +86,26 @@ def register_sort_enum(self, obj_type, sort_enum):
raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum))
self._registry_sort_enums[obj_type] = sort_enum

def get_sort_enum_for_object_type(self, obj_type):
def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType):
return self._registry_sort_enums.get(obj_type)

def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]]):
if not isinstance(union, graphene.Union):
raise TypeError(
"Expected graphene.Union, but got: {!r}".format(union)
)

for obj_type in obj_types:
if not isinstance(obj_type, type(graphene.ObjectType)):
raise TypeError(
"Expected Graphene ObjectType, but got: {!r}".format(obj_type)
)

self._registry_unions[frozenset(obj_types)] = union

def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]):
return self._registry_unions.get(frozenset(obj_types))


registry = None

10 changes: 8 additions & 2 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,8 @@
from decimal import Decimal
from typing import List, Optional, Tuple

from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table,
func, select)
from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, Numeric,
String, Table, func, select)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import column_property, composite, mapper, relationship
@@ -228,3 +228,9 @@ def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']:
@hybrid_property
def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']:
return None


class KeyedModel(Base):
__tablename__ = "test330"
id = Column(Integer(), primary_key=True)
reporter_number = Column("% reporter_number", Numeric, key="reporter_number")
243 changes: 197 additions & 46 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import enum
import sys
from typing import Dict, Union

import pytest
import sqlalchemy_utils as sqa_utils
from sqlalchemy import Column, func, select, types
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property, composite
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType

import graphene
from graphene import Boolean, Float, Int, Scalar, String
from graphene.relay import Node
from graphene.types.datetime import Date, DateTime, Time
from graphene.types.json import JSONString
from graphene.types.structures import List, Structure
from graphene.types.structures import Structure

from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_composite,
convert_sqlalchemy_hybrid_method,
convert_sqlalchemy_relationship)
from ..fields import (UnsortedSQLAlchemyConnectionField,
default_connection_field_factory)
from ..registry import Registry, get_global_registry
from ..types import SQLAlchemyObjectType
from ..types import ORMField, SQLAlchemyObjectType
from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart,
ShoppingCartItem)

@@ -51,23 +51,117 @@ class Model(declarative_base()):
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)


def test_should_unknown_sqlalchemy_field_raise_exception():
re_err = "Don't know how to convert the SQLAlchemy field"
with pytest.raises(Exception, match=re_err):
# support legacy Binary type and subsequent LargeBinary
get_field(getattr(types, 'LargeBinary', types.BINARY)())
def get_hybrid_property_type(prop_method):
class Model(declarative_base()):
__tablename__ = 'model'
id_ = Column(types.Integer, primary_key=True)
prop = prop_method

column_prop = inspect(Model).all_orm_descriptors['prop']
return convert_sqlalchemy_hybrid_method(column_prop, mock_resolver(), **ORMField().kwargs)


def test_hybrid_prop_int():
@hybrid_property
def prop_method() -> int:
return 42

assert get_hybrid_property_type(prop_method).type == graphene.Int


@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10")
def test_hybrid_prop_scalar_union_310():
@hybrid_property
def prop_method() -> int | str:
return "not allowed in gql schema"

with pytest.raises(ValueError,
match=r"Cannot convert hybrid_property Union to "
r"graphene.Union: the Union contains scalars. \.*"):
get_hybrid_property_type(prop_method)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10")
def test_hybrid_prop_scalar_union_and_optional_310():
"""Checks if the use of Optionals does not interfere with non-conform scalar return types"""

@hybrid_property
def prop_method() -> int | None:
return 42

assert get_hybrid_property_type(prop_method).type == graphene.Int


@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10")
def test_should_union_work_310():
reg = Registry()

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
registry = reg

class ShoppingCartType(SQLAlchemyObjectType):
class Meta:
model = ShoppingCartItem
registry = reg

@hybrid_property
def prop_method() -> Union[PetType, ShoppingCartType]:
return None

@hybrid_property
def prop_method_2() -> Union[ShoppingCartType, PetType]:
return None

field_type_1 = get_hybrid_property_type(prop_method).type
field_type_2 = get_hybrid_property_type(prop_method_2).type

assert isinstance(field_type_1, graphene.Union)
assert field_type_1 is field_type_2

# TODO verify types of the union


@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10")
def test_should_union_work_310():
reg = Registry()

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
registry = reg

class ShoppingCartType(SQLAlchemyObjectType):
class Meta:
model = ShoppingCartItem
registry = reg

@hybrid_property
def prop_method() -> PetType | ShoppingCartType:
return None

@hybrid_property
def prop_method_2() -> ShoppingCartType | PetType:
return None

def test_should_date_convert_string():
assert get_field(types.Date()).type == graphene.String
field_type_1 = get_hybrid_property_type(prop_method).type
field_type_2 = get_hybrid_property_type(prop_method_2).type

assert isinstance(field_type_1, graphene.Union)
assert field_type_1 is field_type_2


def test_should_datetime_convert_datetime():
assert get_field(types.DateTime()).type == DateTime
assert get_field(types.DateTime()).type == graphene.DateTime


def test_should_time_convert_time():
assert get_field(types.Time()).type == graphene.Time

def test_should_time_convert_string():
assert get_field(types.Time()).type == graphene.String

def test_should_date_convert_date():
assert get_field(types.Date()).type == graphene.Date


def test_should_string_convert_string():
@@ -86,6 +180,30 @@ def test_should_unicodetext_convert_string():
assert get_field(types.UnicodeText()).type == graphene.String


def test_should_tsvector_convert_string():
assert get_field(sqa_utils.TSVectorType()).type == graphene.String


def test_should_email_convert_string():
assert get_field(sqa_utils.EmailType()).type == graphene.String


def test_should_URL_convert_string():
assert get_field(sqa_utils.URLType()).type == graphene.String


def test_should_IPaddress_convert_string():
assert get_field(sqa_utils.IPAddressType()).type == graphene.String


def test_should_inet_convert_string():
assert get_field(postgresql.INET()).type == graphene.String


def test_should_cidr_convert_string():
assert get_field(postgresql.CIDR()).type == graphene.String


def test_should_enum_convert_enum():
field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two"))))
field_type = field.type()
@@ -142,7 +260,7 @@ def test_should_numeric_convert_float():


def test_should_choice_convert_enum():
field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")]))
field = get_field(sqa_utils.ChoiceType([(u"es", u"Spanish"), (u"en", u"English")]))
graphene_type = field.type
assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.name == "MODEL_COLUMN"
@@ -155,20 +273,40 @@ class TestEnum(enum.Enum):
es = u"Spanish"
en = u"English"

field = get_field(ChoiceType(TestEnum, impl=types.String()))
field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String()))
graphene_type = field.type
assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.name == "MODEL_COLUMN"
assert graphene_type._meta.enum.__members__["es"].value == "Spanish"
assert graphene_type._meta.enum.__members__["en"].value == "English"


def test_choice_enum_column_key_name_issue_301():
"""
Verifies that the sort enum name is generated from the column key instead of the name,
in case the column has an invalid enum name. See #330
"""

class TestEnum(enum.Enum):
es = u"Spanish"
en = u"English"

testChoice = Column("% descuento1", sqa_utils.ChoiceType(TestEnum, impl=types.String()), key="descuento1")
field = get_field_from_column(testChoice)

graphene_type = field.type
assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.name == "MODEL_DESCUENTO1"
assert graphene_type._meta.enum.__members__["es"].value == "Spanish"
assert graphene_type._meta.enum.__members__["en"].value == "English"


def test_should_intenum_choice_convert_enum():
class TestEnum(enum.IntEnum):
one = 1
two = 2

field = get_field(ChoiceType(TestEnum, impl=types.String()))
field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String()))
graphene_type = field.type
assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.name == "MODEL_COLUMN"
@@ -185,13 +323,22 @@ def test_should_columproperty_convert():


def test_should_scalar_list_convert_list():
field = get_field(ScalarListType())
field = get_field(sqa_utils.ScalarListType())
assert isinstance(field.type, graphene.List)
assert field.type.of_type == graphene.String


def test_should_jsontype_convert_jsonstring():
assert get_field(JSONType()).type == JSONString
assert get_field(sqa_utils.JSONType()).type == graphene.JSONString
assert get_field(types.JSON).type == graphene.JSONString


def test_should_variant_int_convert_int():
assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int


def test_should_variant_string_convert_string():
assert get_field(types.Variant(types.String(), {})).type == graphene.String


def test_should_manytomany_convert_connectionorlist():
@@ -291,7 +438,11 @@ class Meta:


def test_should_postgresql_uuid_convert():
assert get_field(postgresql.UUID()).type == graphene.String
assert get_field(postgresql.UUID()).type == graphene.UUID


def test_should_sqlalchemy_utils_uuid_convert():
assert get_field(sqa_utils.UUIDType()).type == graphene.UUID


def test_should_postgresql_enum_convert():
@@ -405,8 +556,8 @@ class Meta:
# Check ShoppingCartItem's Properties and Return Types
#######################################################

shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = {
'hybrid_prop_shopping_cart': List(ShoppingCartType)
shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = {
'hybrid_prop_shopping_cart': graphene.List(ShoppingCartType)
}

assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([
@@ -421,37 +572,37 @@ class Meta:

# this is a simple way of showing the failed property name
# instead of having to unroll the loop.
assert (
(hybrid_prop_name, str(hybrid_prop_field.type)) ==
(hybrid_prop_name, str(hybrid_prop_expected_return_type))
assert (hybrid_prop_name, str(hybrid_prop_field.type)) == (
hybrid_prop_name,
str(hybrid_prop_expected_return_type),
)
assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property

###################################################
# Check ShoppingCart's Properties and Return Types
###################################################

shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = {
shopping_cart_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = {
# Basic types
"hybrid_prop_str": String,
"hybrid_prop_int": Int,
"hybrid_prop_float": Float,
"hybrid_prop_bool": Boolean,
"hybrid_prop_decimal": String, # Decimals should be serialized Strings
"hybrid_prop_date": Date,
"hybrid_prop_time": Time,
"hybrid_prop_datetime": DateTime,
"hybrid_prop_str": graphene.String,
"hybrid_prop_int": graphene.Int,
"hybrid_prop_float": graphene.Float,
"hybrid_prop_bool": graphene.Boolean,
"hybrid_prop_decimal": graphene.String, # Decimals should be serialized Strings
"hybrid_prop_date": graphene.Date,
"hybrid_prop_time": graphene.Time,
"hybrid_prop_datetime": graphene.DateTime,
# Lists and Nested Lists
"hybrid_prop_list_int": List(Int),
"hybrid_prop_list_date": List(Date),
"hybrid_prop_nested_list_int": List(List(Int)),
"hybrid_prop_deeply_nested_list_int": List(List(List(Int))),
"hybrid_prop_list_int": graphene.List(graphene.Int),
"hybrid_prop_list_date": graphene.List(graphene.Date),
"hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)),
"hybrid_prop_deeply_nested_list_int": graphene.List(graphene.List(graphene.List(graphene.Int))),
"hybrid_prop_first_shopping_cart_item": ShoppingCartItemType,
"hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType),
"hybrid_prop_unsupported_type_tuple": String,
"hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType),
"hybrid_prop_unsupported_type_tuple": graphene.String,
# Self Referential List
"hybrid_prop_self_referential": ShoppingCartType,
"hybrid_prop_self_referential_list": List(ShoppingCartType),
"hybrid_prop_self_referential_list": graphene.List(ShoppingCartType),
# Optionals
"hybrid_prop_optional_self_referential": ShoppingCartType,
}
@@ -468,8 +619,8 @@ class Meta:

# this is a simple way of showing the failed property name
# instead of having to unroll the loop.
assert (
(hybrid_prop_name, str(hybrid_prop_field.type)) ==
(hybrid_prop_name, str(hybrid_prop_expected_return_type))
assert (hybrid_prop_name, str(hybrid_prop_field.type)) == (
hybrid_prop_name,
str(hybrid_prop_expected_return_type),
)
assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property
56 changes: 55 additions & 1 deletion graphene_sqlalchemy/tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
from sqlalchemy.types import Enum as SQLAlchemyEnum

import graphene
from graphene import Enum as GrapheneEnum

from ..registry import Registry
from ..types import SQLAlchemyObjectType
from ..utils import EnumValue
from .models import Pet
from .models import Pet, Reporter


def test_register_object_type():
@@ -126,3 +127,56 @@ class Meta:
re_err = r"Expected Graphene Enum, but got: .*PetType.*"
with pytest.raises(TypeError, match=re_err):
reg.register_sort_enum(PetType, PetType)


def test_register_union():
reg = Registry()

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
registry = reg

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter

union_types = [PetType, ReporterType]
union = graphene.Union('ReporterPet', tuple(union_types))

reg.register_union_type(union, union_types)

assert reg.get_union_for_object_types(union_types) == union
# Order should not matter
assert reg.get_union_for_object_types([ReporterType, PetType]) == union


def test_register_union_scalar():
reg = Registry()

union_types = [graphene.String, graphene.Int]
union = graphene.Union('StringInt', tuple(union_types))

re_err = r"Expected Graphene ObjectType, but got: .*String.*"
with pytest.raises(TypeError, match=re_err):
reg.register_union_type(union, union_types)


def test_register_union_incorrect_types():
reg = Registry()

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
registry = reg

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter

union_types = [PetType, ReporterType]
union = PetType

re_err = r"Expected graphene.Union, but got: .*PetType.*"
with pytest.raises(TypeError, match=re_err):
reg.register_union_type(union, union_types)
25 changes: 24 additions & 1 deletion graphene_sqlalchemy/tests/test_sort_enums.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..utils import to_type_name
from .models import Base, HairKind, Pet
from .models import Base, HairKind, KeyedModel, Pet
from .test_query import to_std_dicts


@@ -383,3 +383,26 @@ def makeNodes(nodeList):
assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [
node["node"]["name"] for node in result.data["noDefaultSort"]["edges"]
]


def test_sort_enum_from_key_issue_330():
"""
Verifies that the sort enum name is generated from the column key instead of the name,
in case the column has an invalid enum name. See #330
"""

class KeyedType(SQLAlchemyObjectType):
class Meta:
model = KeyedModel

sort_enum = KeyedType.sort_enum()
assert isinstance(sort_enum, type(Enum))
assert sort_enum._meta.name == "KeyedTypeSortEnum"
assert list(sort_enum._meta.enum.__members__) == [
"ID_ASC",
"ID_DESC",
"REPORTER_NUMBER_ASC",
"REPORTER_NUMBER_DESC",
]
assert str(sort_enum.REPORTER_NUMBER_ASC.value.value) == 'test330."% reporter_number" ASC'
assert str(sort_enum.REPORTER_NUMBER_DESC.value.value) == 'test330."% reporter_number" DESC'
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -76,7 +76,7 @@ class Meta:

assert sorted(list(ReporterType._meta.fields.keys())) == sorted([
# Columns
"column_prop", # SQLAlchemy retuns column properties first
"column_prop",
"id",
"first_name",
"last_name",
8 changes: 6 additions & 2 deletions graphene_sqlalchemy/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@

from graphene import Enum, List, ObjectType, Schema, String

from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model,
to_enum_value_name, to_type_name)
from ..utils import (DummyImport, get_session, sort_argument_for_model,
sort_enum_for_model, to_enum_value_name, to_type_name)
from .models import Base, Editor, Pet


@@ -99,3 +99,7 @@ class MultiplePK(Base):
assert set(arg.default_value) == set(
(MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc")
)

def test_dummy_import():
dummy_module = DummyImport()
assert dummy_module.foo == object
9 changes: 7 additions & 2 deletions graphene_sqlalchemy/utils.py
Original file line number Diff line number Diff line change
@@ -8,8 +8,6 @@
from sqlalchemy.orm import class_mapper, object_mapper
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError

from graphene_sqlalchemy.registry import get_global_registry


def get_session(context):
return context.get("session")
@@ -203,7 +201,14 @@ def safe_isinstance_checker(arg):


def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]:
from graphene_sqlalchemy.registry import get_global_registry
try:
return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys())))
except StopIteration:
pass


class DummyImport:
"""The dummy module returns 'object' for a query for any member"""
def __getattr__(self, name):
return object