diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 30f6dedd..00000000 --- a/.flake8 +++ /dev/null @@ -1,4 +0,0 @@ -[flake8] -ignore = E203,W503 -exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs -max-line-length = 120 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66db3814..470a29eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.10 + python: python3.7 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -16,6 +16,14 @@ repos: hooks: - id: isort name: isort (python) + - repo: https://github.com/asottile/pyupgrade + rev: v2.37.3 + hooks: + - id: pyupgrade + - repo: https://github.com/psf/black + rev: 22.6.0 + hooks: + - id: black - repo: https://github.com/PyCQA/flake8 rev: 4.0.0 hooks: diff --git a/docs/conf.py b/docs/conf.py index 3fa6391d..9c9fc1d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,6 +1,6 @@ import os -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" # -*- coding: utf-8 -*- # @@ -34,46 +34,46 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", ] if not on_rtd: extensions += [ - 'sphinx.ext.githubpages', + "sphinx.ext.githubpages", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Graphene Django' -copyright = u'Graphene 2016' -author = u'Syrus Akbary' +project = "Graphene Django" +copyright = "Graphene 2016" +author = "Syrus Akbary" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = u'1.0' +version = "1.0" # The full version, including alpha/beta/rc tags. -release = u'1.0.dev' +release = "1.0.dev" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -94,7 +94,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -116,7 +116,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -175,7 +175,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -255,34 +255,30 @@ # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'Graphenedoc' +htmlhelp_basename = "Graphenedoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Graphene.tex', u'Graphene Documentation', - u'Syrus Akbary', 'manual'), + (master_doc, "Graphene.tex", "Graphene Documentation", "Syrus Akbary", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -323,8 +319,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'graphene_django', u'Graphene Django Documentation', - [author], 1) + (master_doc, "graphene_django", "Graphene Django Documentation", [author], 1) ] # If true, show URL addresses after external links. @@ -338,9 +333,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Graphene-Django', u'Graphene Django Documentation', - author, 'Graphene Django', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Graphene-Django", + "Graphene Django Documentation", + author, + "Graphene Django", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. @@ -414,7 +415,7 @@ # epub_post_files = [] # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # The depth of the table of contents in toc.ncx. # @@ -446,4 +447,4 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index ca4d4122..74ec7ca9 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -15,24 +15,25 @@ def init_db(): # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index efbbe690..38f0fd0a 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index ea525e3b..c4a91e63 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -10,26 +10,27 @@ class Department(SQLAlchemyObjectType): class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee.connection, sort=Employee.sort_argument()) + Employee.connection, sort=Employee.sort_argument() + ) # Allows sorting over multiple columns, by default over the primary key all_roles = SQLAlchemyConnectionField(Role.connection) # Disable sorting over this field diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index 05352529..64d305ea 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,37 +1,45 @@ from database import db_session, init_db from schema import schema -from graphql_server import (HttpQueryError, default_format_error, - encode_execution_results, json_encode, - load_json_body, run_http_query) - - -class App(): - def __init__(self): - init_db() - - def query(self, request): - data = self.parse_body(request) - execution_results, params = run_http_query( - schema, - 'post', - data) - result, status_code = encode_execution_results( - execution_results, - format_error=default_format_error,is_batch=False, encode=json_encode) - return result - - def parse_body(self,request): - # We use mimetype here since we don't need the other - # information provided by content_type - content_type = request.mimetype - if content_type == 'application/graphql': - return {'query': request.data.decode('utf8')} - - elif content_type == 'application/json': - return load_json_body(request.data.decode('utf8')) - - elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): - return request.form - - return {} +from graphql_server import ( + HttpQueryError, + default_format_error, + encode_execution_results, + json_encode, + load_json_body, + run_http_query, +) + + +class App: + def __init__(self): + init_db() + + def query(self, request): + data = self.parse_body(request) + execution_results, params = run_http_query(schema, "post", data) + result, status_code = encode_execution_results( + execution_results, + format_error=default_format_error, + is_batch=False, + encode=json_encode, + ) + return result + + def parse_body(self, request): + # We use mimetype here since we don't need the other + # information provided by content_type + content_type = request.mimetype + if content_type == "application/graphql": + return {"query": request.data.decode("utf8")} + + elif content_type == "application/json": + return load_json_body(request.data.decode("utf8")) + + elif content_type in ( + "application/x-www-form-urlencoded", + "multipart/form-data", + ): + return request.form + + return {} diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index ca4d4122..74ec7ca9 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -15,24 +15,25 @@ def init_db(): # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/nameko_sqlalchemy/models.py b/examples/nameko_sqlalchemy/models.py index efbbe690..38f0fd0a 100644 --- a/examples/nameko_sqlalchemy/models.py +++ b/examples/nameko_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/nameko_sqlalchemy/service.py b/examples/nameko_sqlalchemy/service.py index d9c519c9..7f4c5078 100644 --- a/examples/nameko_sqlalchemy/service.py +++ b/examples/nameko_sqlalchemy/service.py @@ -4,8 +4,8 @@ class DepartmentService: - name = 'department' + name = "department" - @http('POST', '/graphql') + @http("POST", "/graphql") def query(self, request): return App().query(request) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index f6f14a6e..275d5904 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -7,8 +7,7 @@ from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import (is_graphene_version_less_than, - is_sqlalchemy_version_less_than) +from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than class RelationshipLoader(aiodataloader.DataLoader): @@ -59,13 +58,13 @@ async def batch_load_fn(self, parents): # For our purposes, the query_context will only used to get the session query_context = None - if is_sqlalchemy_version_less_than('1.4'): + if is_sqlalchemy_version_less_than("1.4"): query_context = QueryContext(session.query(parent_mapper.entity)) else: parent_mapper_query = session.query(parent_mapper.entity) query_context = parent_mapper_query._compile_context() - if is_sqlalchemy_version_less_than('1.4'): + if is_sqlalchemy_version_less_than("1.4"): self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, @@ -82,9 +81,7 @@ async def batch_load_fn(self, parents): child_mapper, None, ) - return [ - getattr(parent, self.relationship_prop.key) for parent in parents - ] + return [getattr(parent, self.relationship_prop.key) for parent in parents] # Cache this across `batch_load_fn` calls @@ -117,7 +114,7 @@ def _get_loader(relationship_prop): loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) if loader is None or loader.loop != get_event_loop(): selectin_loader = strategies.SelectInLoader( - relationship_prop, (('lazy', 'selectin'),) + relationship_prop, (("lazy", "selectin"),) ) loader = RelationshipLoader( relationship_prop=relationship_prop, diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 1e7846eb..d1873c2b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -15,13 +15,16 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import (BatchSQLAlchemyConnectionField, - default_connection_field_factory) +from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import (DummyImport, registry_sqlalchemy_model_from_str, - safe_isinstance, singledispatchbymatchfunction, - value_equals) +from .utils import ( + DummyImport, + registry_sqlalchemy_model_from_str, + safe_isinstance, + singledispatchbymatchfunction, + value_equals, +) try: from typing import ForwardRef @@ -39,7 +42,7 @@ except ImportError: sqa_utils = DummyImport() -is_selectin_available = getattr(strategies, 'SelectInLoader', None) +is_selectin_available = getattr(strategies, "SelectInLoader", None) def get_column_doc(column): @@ -50,8 +53,14 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, - orm_field_name, **field_kwargs): +def convert_sqlalchemy_relationship( + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, +): """ :param sqlalchemy.RelationshipProperty relationship_prop: :param SQLAlchemyObjectType obj_type: @@ -65,24 +74,34 @@ def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_fiel def dynamic_type(): """:rtype: Field|None""" direction = relationship_prop.direction - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) batching_ = batching if is_selectin_available else False if not child_type: return None if direction == interfaces.MANYTOONE or not relationship_prop.uselist: - return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, - **field_kwargs) + return _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching_, orm_field_name, **field_kwargs + ) if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, - connection_field_factory, **field_kwargs) + return _convert_o2m_or_m2m_relationship( + relationship_prop, + obj_type, + batching_, + connection_field_factory, + **field_kwargs, + ) return graphene.Dynamic(dynamic_type) -def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): +def _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs +): """ Convert one-to-one or many-to-one relationshsip. Return an object field. @@ -93,17 +112,24 @@ def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_ :param dict field_kwargs: :rtype: Field """ - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) resolver = get_custom_resolver(obj_type, orm_field_name) if resolver is None: - resolver = get_batch_resolver(relationship_prop) if batching else \ - get_attr_resolver(obj_type, relationship_prop.key) + resolver = ( + get_batch_resolver(relationship_prop) + if batching + else get_attr_resolver(obj_type, relationship_prop.key) + ) return graphene.Field(child_type, resolver=resolver, **field_kwargs) -def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): +def _convert_o2m_or_m2m_relationship( + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs +): """ Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. @@ -114,30 +140,34 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn :param dict field_kwargs: :rtype: Field """ - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) if not child_type._meta.connection: return graphene.Field(graphene.List(child_type), **field_kwargs) # TODO Allow override of connection_field_factory and resolver via ORMField if connection_field_factory is None: - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ - default_connection_field_factory - - return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) + connection_field_factory = ( + BatchSQLAlchemyConnectionField.from_relationship + if batching + else default_connection_field_factory + ) + + return connection_field_factory( + relationship_prop, obj_type._meta.registry, **field_kwargs + ) def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): - if 'type_' not in field_kwargs: - field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) + if "type_" not in field_kwargs: + field_kwargs["type_"] = convert_hybrid_property_return_type(hybrid_prop) - if 'description' not in field_kwargs: - field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) + if "description" not in field_kwargs: + field_kwargs["description"] = getattr(hybrid_prop, "__doc__", None) - return graphene.Field( - resolver=resolver, - **field_kwargs - ) + return graphene.Field(resolver=resolver, **field_kwargs) def convert_sqlalchemy_composite(composite_prop, registry, resolver): @@ -177,14 +207,14 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) - field_kwargs.setdefault('required', not is_column_nullable(column)) - field_kwargs.setdefault('description', get_column_doc(column)) - - return graphene.Field( - resolver=resolver, - **field_kwargs + field_kwargs.setdefault( + "type_", + convert_sqlalchemy_type(getattr(column, "type", None), column, registry), ) + field_kwargs.setdefault("required", not is_column_nullable(column)) + field_kwargs.setdefault("description", get_column_doc(column)) + + return graphene.Field(resolver=resolver, **field_kwargs) @singledispatch @@ -271,14 +301,20 @@ def convert_scalar_list_to_list(type, column, registry=None): def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else graphene.List(init_array_list_recursive(inner_type, n - 1)) + return ( + inner_type + if n == 0 + else graphene.List(init_array_list_recursive(inner_type, n - 1)) + ) @convert_sqlalchemy_type.register(sqa_types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return graphene.List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) + return graphene.List( + init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1) + ) @convert_sqlalchemy_type.register(postgresql.HSTORE) @@ -313,8 +349,8 @@ def convert_sqlalchemy_hybrid_property_type(arg: Any): # No valid type found, warn and fall back to graphene.String warnings.warn( - (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." - "Falling back to \"graphene.String\"") + f'I don\'t know how to generate a GraphQL type out of a "{arg}" type.' + 'Falling back to "graphene.String"' ) return graphene.String @@ -368,15 +404,17 @@ def is_union(arg) -> bool: if isinstance(arg, UnionType): return True - return getattr(arg, '__origin__', None) == typing.Union + return getattr(arg, "__origin__", None) == typing.Union -def graphene_union_for_py_union(obj_types: typing.List[graphene.ObjectType], registry) -> graphene.Union: +def graphene_union_for_py_union( + obj_types: typing.List[graphene.ObjectType], registry +) -> graphene.Union: union_type = registry.get_union_for_object_types(obj_types) if union_type is None: # Union Name is name of the three - union_name = ''.join(sorted([obj_type._meta.name for obj_type in obj_types])) + union_name = "".join(sorted(obj_type._meta.name for obj_type in obj_types)) union_type = graphene.Union(union_name, obj_types) registry.register_union_type(union_type, obj_types) @@ -411,16 +449,25 @@ def convert_sqlalchemy_hybrid_property_union(arg): return graphene_types[0] # Now check if every type is instance of an ObjectType - if not all(isinstance(graphene_type, type(graphene.ObjectType)) for graphene_type in graphene_types): - raise ValueError("Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " - "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " - "or use an ORMField to override this behaviour.") - - return graphene_union_for_py_union(cast(typing.List[graphene.ObjectType], list(graphene_types)), - get_global_registry()) + if not all( + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types + ): + raise ValueError( + "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " + "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " + "or use an ORMField to override this behaviour." + ) + + return graphene_union_for_py_union( + cast(typing.List[graphene.ObjectType], list(graphene_types)), + get_global_registry(), + ) -@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +@convert_sqlalchemy_hybrid_property_type.register( + lambda x: getattr(x, "__origin__", None) in [list, typing.List] +) def convert_sqlalchemy_hybrid_property_type_list_t(arg): # type is either list[T] or List[T], generic argument at __args__[0] internal_type = arg.__args__[0] @@ -459,6 +506,6 @@ def convert_sqlalchemy_hybrid_property_bare_str(arg): def convert_hybrid_property_return_type(hybrid_prop): # Grab the original method's return type annotations from inside the hybrid property - return_type_annotation = hybrid_prop.fget.__annotations__.get('return', str) + return_type_annotation = hybrid_prop.fget.__annotations__.get("return", str) return convert_sqlalchemy_hybrid_property_type(return_type_annotation) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 19f40b7f..97f8997c 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -18,9 +18,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): The Enum value names are converted to upper case if necessary. """ if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum_class = sa_enum.enum_class if enum_class: if all(to_enum_value_name(key) == key for key in enum_class.__members__): @@ -45,9 +43,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): def enum_for_sa_enum(sa_enum, registry): """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum = registry.get_graphene_enum_for_sa_enum(sa_enum) if not enum: enum = _convert_sa_to_graphene_enum(sa_enum) @@ -60,11 +56,9 @@ def enum_for_field(obj_type, field_name): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + raise TypeError("Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): - raise TypeError( - "Expected a field name, but got: {!r}".format(field_name)) + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) if orm_field is None: @@ -166,7 +160,7 @@ def sort_argument_for_object_type( get_symbol_name=None, has_default=True, ): - """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + """ "Returns Graphene Argument for sorting the given SQLAlchemyObjectType. Parameters - obj_type : SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 9b4b8436..2cb53c55 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -26,9 +26,7 @@ def type(self): assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(nullable_type.__name__) - assert ( - nullable_type.connection - ), "The type {} doesn't have a connection".format( + assert nullable_type.connection, "The type {} doesn't have a connection".format( nullable_type.__name__ ) assert type_ == nullable_type, ( @@ -39,7 +37,11 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): + if ( + "sort" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): # Let super class raise if type is not a Connection try: kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) @@ -151,7 +153,9 @@ class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): def connection_resolver(cls, resolver, connection_type, model, root, info, **args): if root is None: resolved = resolver(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + on_resolve = partial( + cls.resolve_connection, connection_type, model, info, args + ) else: relationship_prop = None for relationship in root.__class__.__mapper__.relationships: @@ -159,7 +163,9 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg relationship_prop = relationship break resolved = get_batch_resolver(relationship_prop)(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection_type, root, info, args) + on_resolve = partial( + cls.resolve_connection, connection_type, root, info, args + ) if is_thenable(resolved): return Promise.resolve(resolved).then(on_resolve) @@ -170,7 +176,11 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg def from_relationship(cls, relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + return cls( + model_type.connection, + resolver=get_batch_resolver(relationship), + **field_kwargs + ) def default_connection_field_factory(relationship, registry, **field_kwargs): @@ -185,8 +195,8 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): def createConnectionField(type_, **field_kwargs): warnings.warn( - 'createConnectionField is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "createConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) return __connectionFactory(type_, **field_kwargs) @@ -194,8 +204,8 @@ def createConnectionField(type_, **field_kwargs): def registerConnectionFieldFactory(factoryMethod): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory @@ -204,8 +214,8 @@ def registerConnectionFieldFactory(factoryMethod): def unregisterConnectionFieldFactory(): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 80470d9b..8f2bc9e7 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -20,8 +20,9 @@ def __init__(self): def register(self, obj_type): from .types import SQLAlchemyObjectType + if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -40,7 +41,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -76,8 +77,9 @@ def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): def register_sort_enum(self, obj_type, sort_enum: Enum): from .types import SQLAlchemyObjectType + if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -89,11 +91,11 @@ def register_sort_enum(self, obj_type, sort_enum: Enum): def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) - def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]]): + def register_union_type( + self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]] + ): if not isinstance(union, graphene.Union): - raise TypeError( - "Expected graphene.Union, but got: {!r}".format(union) - ) + raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) for obj_type in obj_types: if not isinstance(obj_type, type(graphene.ObjectType)): @@ -103,7 +105,7 @@ def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphe self._registry_unions[frozenset(obj_types)] = union - def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]): + def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): return self._registry_unions.get(frozenset(obj_types)) diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py index 83a6e35d..e8e61911 100644 --- a/graphene_sqlalchemy/resolvers.py +++ b/graphene_sqlalchemy/resolvers.py @@ -7,7 +7,7 @@ def get_custom_resolver(obj_type, orm_field_name): does not have a `resolver`, we need to re-implement that logic here so users are able to override the default resolvers that we provide. """ - resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + resolver = getattr(obj_type, "resolve_{}".format(orm_field_name), None) if resolver: return get_unbound_function(resolver) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 34ba9d8a..357ad96e 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -8,7 +8,7 @@ from ..registry import reset_global_registry from .models import Base, CompositeFullName -test_db_url = 'sqlite://' # use in-memory database for tests +test_db_url = "sqlite://" # use in-memory database for tests @pytest.fixture(autouse=True) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index c7a1d664..fd5d3b21 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -5,8 +5,18 @@ from decimal import Decimal from typing import List, Optional, Tuple -from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, Numeric, - String, Table, func, select) +from sqlalchemy import ( + Column, + Date, + Enum, + ForeignKey, + Integer, + Numeric, + String, + Table, + func, + select, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import column_property, composite, mapper, relationship @@ -15,8 +25,8 @@ class HairKind(enum.Enum): - LONG = 'long' - SHORT = 'short' + LONG = "long" + SHORT = "short" Base = declarative_base() @@ -64,7 +74,9 @@ class Reporter(Base): last_name = Column(String(30), doc="Last name") email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") + pets = relationship( + "Pet", secondary=association_table, backref="reporters", order_by="Pet.id" + ) articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) @@ -101,7 +113,9 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) class Article(Base): @@ -155,7 +169,7 @@ class ShoppingCartItem(Base): id = Column(Integer(), primary_key=True) @hybrid_property - def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + def hybrid_prop_shopping_cart(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] @@ -210,11 +224,17 @@ def hybrid_prop_list_date(self) -> List[datetime.date]: @hybrid_property def hybrid_prop_nested_list_int(self) -> List[List[int]]: - return [self.hybrid_prop_list_int, ] + return [ + self.hybrid_prop_list_int, + ] @hybrid_property def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: - return [[self.hybrid_prop_list_int, ], ] + return [ + [ + self.hybrid_prop_list_int, + ], + ] # Other SQLAlchemy Instances @hybrid_property @@ -234,17 +254,17 @@ def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: # Self-references @hybrid_property - def hybrid_prop_self_referential(self) -> 'ShoppingCart': + def hybrid_prop_self_referential(self) -> "ShoppingCart": return ShoppingCart(id=1) @hybrid_property - def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + def hybrid_prop_self_referential_list(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] # Optional[T] @hybrid_property - def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: + def hybrid_prop_optional_self_referential(self) -> Optional["ShoppingCart"]: return None diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index fc4e6649..90df0279 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -7,8 +7,7 @@ import graphene from graphene import Connection, relay -from ..fields import (BatchSQLAlchemyConnectionField, - default_connection_field_factory) +from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reader, Reporter @@ -17,6 +16,7 @@ class MockLoggingHandler(logging.Handler): """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): self.messages = [] logging.Handler.__init__(self, *args, **kwargs) @@ -28,7 +28,7 @@ def emit(self, record): @contextlib.contextmanager def mock_sqlalchemy_logging_handler(): logging.basicConfig() - sql_logger = logging.getLogger('sqlalchemy.engine') + sql_logger = logging.getLogger("sqlalchemy.engine") previous_level = sql_logger.level sql_logger.setLevel(logging.INFO) @@ -65,10 +65,10 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return info.context.get('session').query(Article).all() + return info.context.get("session").query(Article).all() def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() return graphene.Schema(query=Query) @@ -107,8 +107,8 @@ class Query(graphene.ObjectType): return graphene.Schema(query=Query) -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) @pytest.mark.asyncio @@ -116,19 +116,19 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) @@ -140,7 +140,8 @@ async def test_many_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { articles { headline @@ -149,20 +150,26 @@ async def test_many_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN reporters" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -194,19 +201,19 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) @@ -218,7 +225,8 @@ async def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -227,20 +235,26 @@ async def test_one_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -272,27 +286,27 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) @@ -304,7 +318,8 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -317,20 +332,26 @@ async def test_one_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -384,27 +405,27 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -418,7 +439,8 @@ async def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -431,20 +453,26 @@ async def test_many_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -495,9 +523,9 @@ async def test_many_to_many(session_factory): def test_disable_batching_via_ormfield(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -520,7 +548,7 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) @@ -528,7 +556,8 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { favoriteArticle { @@ -536,17 +565,24 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -558,19 +594,25 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 @pytest.mark.asyncio def test_batch_sorting_with_custom_ormfield(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -601,7 +643,8 @@ class Meta: with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = schema.execute( + """ query { reporters(sort: [FIRSTNAME_DESC]) { edges { @@ -611,30 +654,42 @@ class Meta: } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) assert result == { - "reporters": {"edges": [ - {"node": { - "firstname": "Reporter_2", - }}, - {"node": { - "firstname": "Reporter_1", - }}, - ]} + "reporters": { + "edges": [ + { + "node": { + "firstname": "Reporter_2", + } + }, + { + "node": { + "firstname": "Reporter_1", + } + }, + ] + } } - select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM reporters" in message + ] assert len(select_statements) == 2 @pytest.mark.asyncio async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -657,14 +712,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - await schema.execute_async(""" + await schema.execute_async( + """ query { reporters { articles { @@ -676,24 +732,34 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] else: - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 1 def test_connection_factory_field_overrides_batching_is_true(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -716,14 +782,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -735,10 +802,16 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 @@ -751,10 +824,10 @@ async def test_batching_across_nested_relay_schema(session_factory): first_name=first_name, ) session.add(reporter) - article = Article(headline='Article') + article = Article(headline="Article") article.reporter = reporter session.add(article) - reader = Reader(name='Reader') + reader = Reader(name="Reader") reader.articles = [article] session.add(reader) @@ -766,7 +839,8 @@ async def test_batching_across_nested_relay_schema(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { edges { @@ -790,14 +864,16 @@ async def test_batching_across_nested_relay_schema(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) - select_statements = [message for message in messages if 'SELECT' in message] + select_statements = [message for message in messages if "SELECT" in message] assert len(select_statements) == 4 assert select_statements[-1].startswith("SELECT articles_1.id") - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): assert select_statements[-2].startswith("SELECT reporters_1.id") assert "WHERE reporters_1.id IN" in select_statements[-2] else: @@ -810,10 +886,7 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f session = session_factory() for first_name, email in zip("cadbbb", "aaabac"): - reporter_1 = Reporter( - first_name=first_name, - email=email - ) + reporter_1 = Reporter(first_name=first_name, email=email) session.add(reporter_1) article_1 = Article(headline="headline") article_1.reporter = reporter_1 @@ -825,7 +898,8 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f schema = get_full_relay_schema() session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { edges { @@ -836,10 +910,12 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) result = to_std_dicts(result.data) assert [ r["node"]["firstName"] + r["node"]["email"] for r in result["reporters"]["edges"] - ] == ['aa', 'ba', 'bb', 'bc', 'ca', 'da'] + ] == ["aa", "ba", "bb", "bc", "ca", "da"] diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 11e9d0e0..bb105edd 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -7,8 +7,8 @@ from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) def get_schema(): @@ -32,10 +32,10 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return info.context.get('session').query(Article).all() + return info.context.get("session").query(Article).all() def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() return graphene.Schema(query=Query) @@ -46,8 +46,8 @@ def benchmark_query(session_factory, benchmark, query): @benchmark def execute_query(): result = schema.execute( - query, - context_value={"session": session_factory()}, + query, + context_value={"session": session_factory()}, ) assert not result.errors @@ -56,26 +56,29 @@ def test_one_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -84,33 +87,37 @@ def test_one_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { articles { headline @@ -119,41 +126,45 @@ def test_many_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_one_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -166,34 +177,35 @@ def test_one_to_many(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -202,7 +214,10 @@ def test_many_to_many(session_factory, benchmark): session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -215,4 +230,5 @@ def test_many_to_many(session_factory, benchmark): } } } - """) + """, + ) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index a6c2b1bf..812b4cea 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -15,16 +15,23 @@ from graphene.relay import Node from graphene.types.structures import Structure -from ..converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_hybrid_method, - convert_sqlalchemy_relationship) -from ..fields import (UnsortedSQLAlchemyConnectionField, - default_connection_field_factory) +from ..converter import ( + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, +) +from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry from ..types import ORMField, SQLAlchemyObjectType -from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, - ShoppingCartItem) +from .models import ( + Article, + CompositeFullName, + Pet, + Reporter, + ShoppingCart, + ShoppingCartItem, +) def mock_resolver(): @@ -33,32 +40,34 @@ def mock_resolver(): def get_field(sqlalchemy_type, **column_kwargs): class Model(declarative_base()): - __tablename__ = 'model' + __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) - column_prop = inspect(Model).column_attrs['column'] + column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) def get_field_from_column(column_): class Model(declarative_base()): - __tablename__ = 'model' + __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) column = column_ - column_prop = inspect(Model).column_attrs['column'] + column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) def get_hybrid_property_type(prop_method): class Model(declarative_base()): - __tablename__ = 'model' + __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) prop = prop_method - column_prop = inspect(Model).all_orm_descriptors['prop'] - return convert_sqlalchemy_hybrid_method(column_prop, mock_resolver(), **ORMField().kwargs) + column_prop = inspect(Model).all_orm_descriptors["prop"] + return convert_sqlalchemy_hybrid_method( + column_prop, mock_resolver(), **ORMField().kwargs + ) def test_hybrid_prop_int(): @@ -69,19 +78,25 @@ def prop_method() -> int: assert get_hybrid_property_type(prop_method).type == graphene.Int -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_hybrid_prop_scalar_union_310(): @hybrid_property def prop_method() -> int | str: return "not allowed in gql schema" - with pytest.raises(ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*"): + with pytest.raises( + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", + ): get_hybrid_property_type(prop_method) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_hybrid_prop_scalar_union_and_optional_310(): """Checks if the use of Optionals does not interfere with non-conform scalar return types""" @@ -92,8 +107,7 @@ def prop_method() -> int | None: assert get_hybrid_property_type(prop_method).type == graphene.Int -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") -def test_should_union_work_310(): +def test_should_union_work(): reg = Registry() class PetType(SQLAlchemyObjectType): @@ -123,7 +137,9 @@ def prop_method_2() -> Union[ShoppingCartType, PetType]: # TODO verify types of the union -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_should_union_work_310(): reg = Registry() @@ -244,7 +260,9 @@ def test_should_integer_convert_int(): def test_should_primary_integer_convert_id(): - assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID) + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( + graphene.ID + ) def test_should_boolean_convert_boolean(): @@ -260,7 +278,7 @@ def test_should_numeric_convert_float(): def test_should_choice_convert_enum(): - field = get_field(sqa_utils.ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + field = get_field(sqa_utils.ChoiceType([("es", "Spanish"), ("en", "English")])) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -270,8 +288,8 @@ def test_should_choice_convert_enum(): def test_should_enum_choice_convert_enum(): class TestEnum(enum.Enum): - es = u"Spanish" - en = u"English" + es = "Spanish" + en = "English" field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type @@ -288,10 +306,14 @@ def test_choice_enum_column_key_name_issue_301(): """ class TestEnum(enum.Enum): - es = u"Spanish" - en = u"English" + es = "Spanish" + en = "English" - testChoice = Column("% descuento1", sqa_utils.ChoiceType(TestEnum, impl=types.String()), key="descuento1") + testChoice = Column( + "% descuento1", + sqa_utils.ChoiceType(TestEnum, impl=types.String()), + key="descuento1", + ) field = get_field_from_column(testChoice) graphene_type = field.type @@ -315,9 +337,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): - field = get_field_from_column(column_property( - select([func.sum(func.cast(id, types.Integer))]).where(id == 1) - )) + field = get_field_from_column( + column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1)) + ) assert field.type == graphene.Int @@ -347,7 +369,11 @@ class Meta: model = Article dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -359,7 +385,11 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -375,7 +405,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) @@ -387,7 +421,11 @@ class Meta: model = Article dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -399,7 +437,11 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -414,7 +456,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -429,7 +475,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.favorite_article.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -457,7 +507,9 @@ def test_should_postgresql_enum_convert(): def test_should_postgresql_py_enum_convert(): - field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers")) + field = get_field( + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers") + ) field_type = field.type() assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) @@ -519,7 +571,11 @@ def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) field = convert_sqlalchemy_composite( - composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"), + composite( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + doc="Custom Help Text", + ), registry, mock_resolver, ) @@ -535,7 +591,10 @@ def __init__(self, col1, col2): re_err = "Don't know how to convert the composite field" with pytest.raises(Exception, match=re_err): convert_sqlalchemy_composite( - composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))), + composite( + CompositeFullName, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + ), Registry(), mock_resolver, ) @@ -557,17 +616,22 @@ class Meta: ####################################################### shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { - 'hybrid_prop_shopping_cart': graphene.List(ShoppingCartType) + "hybrid_prop_shopping_cart": graphene.List(ShoppingCartType) } - assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ - # Columns - "id", - # Append Hybrid Properties from Above - *shopping_cart_item_expected_types.keys() - ]) + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys(), + ] + ) - for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_item_expected_types.items(): + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] # this is a simple way of showing the failed property name @@ -576,7 +640,9 @@ class Meta: hybrid_prop_name, str(hybrid_prop_expected_return_type), ) - assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + assert ( + hybrid_prop_field.description is None + ) # "doc" is ignored by hybrid property ################################################### # Check ShoppingCart's Properties and Return Types @@ -596,7 +662,9 @@ class Meta: "hybrid_prop_list_int": graphene.List(graphene.Int), "hybrid_prop_list_date": graphene.List(graphene.Date), "hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)), - "hybrid_prop_deeply_nested_list_int": graphene.List(graphene.List(graphene.List(graphene.Int))), + "hybrid_prop_deeply_nested_list_int": graphene.List( + graphene.List(graphene.List(graphene.Int)) + ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), "hybrid_prop_unsupported_type_tuple": graphene.String, @@ -607,14 +675,19 @@ class Meta: "hybrid_prop_optional_self_referential": ShoppingCartType, } - assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ - # Columns - "id", - # Append Hybrid Properties from Above - *shopping_cart_expected_types.keys() - ]) + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys(), + ] + ) - for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_expected_types.items(): + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] # this is a simple way of showing the failed property name @@ -623,4 +696,6 @@ class Meta: hybrid_prop_name, str(hybrid_prop_expected_return_type), ) - assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + assert ( + hybrid_prop_field.description is None + ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index ca376964..cd97a00e 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -54,7 +54,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): @@ -65,7 +65,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): @@ -80,36 +80,35 @@ class PetType(SQLAlchemyObjectType): class Meta: model = Pet - enum = enum_for_field(PetType, 'pet_kind') + enum = enum_for_field(PetType, "pet_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "PetKind" assert [ - (key, value.value) - for key, value in enum._meta.enum.__members__.items() - ] == [("CAT", 'cat'), ("DOG", 'dog')] - enum2 = enum_for_field(PetType, 'pet_kind') + (key, value.value) for key, value in enum._meta.enum.__members__.items() + ] == [("CAT", "cat"), ("DOG", "dog")] + enum2 = enum_for_field(PetType, "pet_kind") assert enum2 is enum - enum2 = PetType.enum_for_field('pet_kind') + enum2 = PetType.enum_for_field("pet_kind") assert enum2 is enum - enum = enum_for_field(PetType, 'hair_kind') + enum = enum_for_field(PetType, "hair_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "HairKind" assert enum._meta.enum is HairKind - enum2 = PetType.enum_for_field('hair_kind') + enum2 = PetType.enum_for_field("hair_kind") assert enum2 is enum re_err = r"Cannot get PetType\.other_kind" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'other_kind') + enum_for_field(PetType, "other_kind") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('other_kind') + PetType.enum_for_field("other_kind") re_err = r"PetType\.name does not map to enum column" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'name') + enum_for_field(PetType, "name") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('name') + PetType.enum_for_field("name") re_err = r"Expected a field name, but got: None" with pytest.raises(TypeError, match=re_err): @@ -119,4 +118,4 @@ class Meta: re_err = "Expected SQLAlchemyObjectType, but got: None" with pytest.raises(TypeError, match=re_err): - enum_for_field(None, 'other_kind') + enum_for_field(None, "other_kind") diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 2782da89..9fed146d 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -4,8 +4,7 @@ from graphene import NonNull, ObjectType from graphene.relay import Connection, Node -from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField) +from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from .models import Editor as EditorModel from .models import Pet as PetModel @@ -21,6 +20,7 @@ class Editor(SQLAlchemyObjectType): class Meta: model = EditorModel + ## # SQLAlchemyConnectionField ## @@ -59,6 +59,7 @@ def test_type_assert_object_has_connection(): with pytest.raises(AssertionError, match="doesn't have a connection"): SQLAlchemyConnectionField(Editor).type + ## # UnsortedSQLAlchemyConnectionField ## @@ -66,8 +67,7 @@ def test_type_assert_object_has_connection(): def test_unsorted_connection_field_removes_sort_arg_if_passed(): editor = UnsortedSQLAlchemyConnectionField( - Editor.connection, - sort=Editor.sort_argument(has_default=True) + Editor.connection, sort=Editor.sort_argument(has_default=True) ) assert "sort" not in editor.args diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 39140814..c7a173df 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -9,19 +9,17 @@ def add_test_data(session): - reporter = Reporter( - first_name='John', last_name='Doe', favorite_pet_kind='cat') + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) session.add(pet) pet.reporters.append(reporter) - article = Article(headline='Hi!') + article = Article(headline="Hi!") article.reporter = reporter session.add(article) - reporter = Reporter( - first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") session.add(reporter) - pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) pet.reporters.append(reporter) session.add(pet) editor = Editor(name="Jack") @@ -163,12 +161,12 @@ class Meta: model = Reporter interfaces = (Node,) - first_name_v2 = ORMField(model_attr='first_name') - hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') - column_prop_v2 = ORMField(model_attr='column_prop') + first_name_v2 = ORMField(model_attr="first_name") + hybrid_prop_v2 = ORMField(model_attr="hybrid_prop") + column_prop_v2 = ORMField(model_attr="column_prop") composite_prop = ORMField() - favorite_article_v2 = ORMField(model_attr='favorite_article') - articles_v2 = ORMField(model_attr='articles') + favorite_article_v2 = ORMField(model_attr="favorite_article") + articles_v2 = ORMField(model_attr="articles") class ArticleType(SQLAlchemyObjectType): class Meta: diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 5166c45f..923bbed1 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -9,7 +9,6 @@ def test_query_pet_kinds(session): add_test_data(session) class PetType(SQLAlchemyObjectType): - class Meta: model = Pet @@ -20,8 +19,9 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - pets = graphene.List(PetType, kind=graphene.Argument( - PetType.enum_for_field('pet_kind'))) + pets = graphene.List( + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) def resolve_reporter(self, _info): return session.query(Reporter).first() @@ -58,27 +58,24 @@ def resolve_pets(self, _info, kind): } """ expected = { - 'reporter': { - 'firstName': 'John', - 'lastName': 'Doe', - 'email': None, - 'favoritePetKind': 'CAT', - 'pets': [{ - 'name': 'Garfield', - 'petKind': 'CAT' - }] + "reporter": { + "firstName": "John", + "lastName": "Doe", + "email": None, + "favoritePetKind": "CAT", + "pets": [{"name": "Garfield", "petKind": "CAT"}], }, - 'reporters': [{ - 'firstName': 'John', - 'favoritePetKind': 'CAT', - }, { - 'firstName': 'Jane', - 'favoritePetKind': 'DOG', - }], - 'pets': [{ - 'name': 'Lassie', - 'petKind': 'DOG' - }] + "reporters": [ + { + "firstName": "John", + "favoritePetKind": "CAT", + }, + { + "firstName": "Jane", + "favoritePetKind": "DOG", + }, + ], + "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) result = schema.execute(query) @@ -125,8 +122,8 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field( - PetType, - kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) def resolve_pet(self, info, kind=None): query = session.query(Pet) diff --git a/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/tests/test_reflected.py index 46e10de9..a3f6c4aa 100644 --- a/graphene_sqlalchemy/tests/test_reflected.py +++ b/graphene_sqlalchemy/tests/test_reflected.py @@ -1,4 +1,3 @@ - from graphene import ObjectType from ..registry import Registry diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index f451f355..cb7e9034 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -142,7 +142,7 @@ class Meta: model = Reporter union_types = [PetType, ReporterType] - union = graphene.Union('ReporterPet', tuple(union_types)) + union = graphene.Union("ReporterPet", tuple(union_types)) reg.register_union_type(union, union_types) @@ -155,7 +155,7 @@ def test_register_union_scalar(): reg = Registry() union_types = [graphene.String, graphene.Int] - union = graphene.Union('StringInt', tuple(union_types)) + union = graphene.Union("StringInt", tuple(union_types)) re_err = r"Expected Graphene ObjectType, but got: .*String.*" with pytest.raises(TypeError, match=re_err): diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index e2510abc..11c7c9a7 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -354,7 +354,7 @@ def makeNodes(nodeList): """ result = schema.execute(queryError, context_value={"session": session}) assert result.errors is not None - assert 'cannot represent non-enum value' in result.errors[0].message + assert "cannot represent non-enum value" in result.errors[0].message queryNoSort = """ query sortTest { @@ -404,5 +404,11 @@ class Meta: "REPORTER_NUMBER_ASC", "REPORTER_NUMBER_DESC", ] - assert str(sort_enum.REPORTER_NUMBER_ASC.value.value) == 'test330."% reporter_number" ASC' - assert str(sort_enum.REPORTER_NUMBER_DESC.value.value) == 'test330."% reporter_number" DESC' + assert ( + str(sort_enum.REPORTER_NUMBER_ASC.value.value) + == 'test330."% reporter_number" ASC' + ) + assert ( + str(sort_enum.REPORTER_NUMBER_DESC.value.value) + == 'test330."% reporter_number" DESC' + ) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 00e8b3af..4afb120d 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,16 +4,31 @@ import sqlalchemy.exc import sqlalchemy.orm.exc -from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, - Node, NonNull, ObjectType, Schema, String) +from graphene import ( + Boolean, + Dynamic, + Field, + Float, + GlobalID, + Int, + List, + Node, + NonNull, + ObjectType, + Schema, + String, +) from graphene.relay import Connection from .. import utils from ..converter import convert_sqlalchemy_composite -from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField, createConnectionField, - registerConnectionFieldFactory, - unregisterConnectionFieldFactory) +from ..fields import ( + SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + createConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory, +) from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, CompositeFullName, Pet, Reporter @@ -21,6 +36,7 @@ def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): pass @@ -28,6 +44,7 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): class Meta: model = 1 @@ -45,7 +62,7 @@ class Meta: reporter = Reporter() session.add(reporter) session.commit() - info = mock.Mock(context={'session': session}) + info = mock.Mock(context={"session": session}) reporter_node = ReporterType.get_node(info, reporter.id) assert reporter == reporter_node @@ -74,91 +91,93 @@ class Meta: model = Article interfaces = (Node,) - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Columns - "column_prop", - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - # Composite - "composite_prop", - # Hybrid - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - # Relationship - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Columns + "column_prop", + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + # Composite + "composite_prop", + # Hybrid + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + # Relationship + "pets", + "articles", + "favorite_article", + ] + ) # column - first_name_field = ReporterType._meta.fields['first_name'] + first_name_field = ReporterType._meta.fields["first_name"] assert first_name_field.type == String assert first_name_field.description == "First name" # column_property - column_prop_field = ReporterType._meta.fields['column_prop'] + column_prop_field = ReporterType._meta.fields["column_prop"] assert column_prop_field.type == Int # "doc" is ignored by column_property assert column_prop_field.description is None # composite - full_name_field = ReporterType._meta.fields['composite_prop'] + full_name_field = ReporterType._meta.fields["composite_prop"] assert full_name_field.type == String # "doc" is ignored by composite assert full_name_field.description is None # hybrid_property - hybrid_prop = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop.type == String # "doc" is ignored by hybrid_property assert hybrid_prop.description is None # hybrid_property_str - hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + hybrid_prop_str = ReporterType._meta.fields["hybrid_prop_str"] assert hybrid_prop_str.type == String # "doc" is ignored by hybrid_property assert hybrid_prop_str.description is None # hybrid_property_int - hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + hybrid_prop_int = ReporterType._meta.fields["hybrid_prop_int"] assert hybrid_prop_int.type == Int # "doc" is ignored by hybrid_property assert hybrid_prop_int.description is None # hybrid_property_float - hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + hybrid_prop_float = ReporterType._meta.fields["hybrid_prop_float"] assert hybrid_prop_float.type == Float # "doc" is ignored by hybrid_property assert hybrid_prop_float.description is None # hybrid_property_bool - hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + hybrid_prop_bool = ReporterType._meta.fields["hybrid_prop_bool"] assert hybrid_prop_bool.type == Boolean # "doc" is ignored by hybrid_property assert hybrid_prop_bool.description is None # hybrid_property_list - hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + hybrid_prop_list = ReporterType._meta.fields["hybrid_prop_list"] assert hybrid_prop_list.type == List(Int) # "doc" is ignored by hybrid_property assert hybrid_prop_list.description is None # hybrid_prop_with_doc - hybrid_prop_with_doc = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc = ReporterType._meta.fields["hybrid_prop_with_doc"] assert hybrid_prop_with_doc.type == String # docstring is picked up from hybrid_prop_with_doc assert hybrid_prop_with_doc.description == "Docstring test" # relationship - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert favorite_article_field.type().type == ArticleType assert favorite_article_field.type().description is None @@ -172,7 +191,7 @@ def convert_composite_class(composite, registry): class ReporterMixin(object): # columns first_name = ORMField(required=True) - last_name = ORMField(description='Overridden') + last_name = ORMField(description="Overridden") class ReporterType(SQLAlchemyObjectType, ReporterMixin): class Meta: @@ -180,8 +199,8 @@ class Meta: interfaces = (Node,) # columns - email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(model_attr='email', type_=Int) + email = ORMField(deprecation_reason="Overridden") + email_v2 = ORMField(model_attr="email", type_=Int) # column_property column_prop = ORMField(type_=String) @@ -190,13 +209,13 @@ class Meta: composite_prop = ORMField() # hybrid_property - hybrid_prop_with_doc = ORMField(description='Overridden') - hybrid_prop = ORMField(description='Overridden') + hybrid_prop_with_doc = ORMField(description="Overridden") + hybrid_prop = ORMField(description="Overridden") # relationships - favorite_article = ORMField(description='Overridden') - articles = ORMField(deprecation_reason='Overridden') - pets = ORMField(description='Overridden') + favorite_article = ORMField(description="Overridden") + articles = ORMField(deprecation_reason="Overridden") + pets = ORMField(description="Overridden") class ArticleType(SQLAlchemyObjectType): class Meta: @@ -209,99 +228,101 @@ class Meta: interfaces = (Node,) use_connection = False - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Fields from ReporterMixin - "first_name", - "last_name", - # Fields from ReporterType - "email", - "email_v2", - "column_prop", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "favorite_article", - "articles", - "pets", - # Then the automatic SQLAlchemy fields - "id", - "favorite_pet_kind", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - ]) - - first_name_field = ReporterType._meta.fields['first_name'] + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Fields from ReporterMixin + "first_name", + "last_name", + # Fields from ReporterType + "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "favorite_article", + "articles", + "pets", + # Then the automatic SQLAlchemy fields + "id", + "favorite_pet_kind", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ] + ) + + first_name_field = ReporterType._meta.fields["first_name"] assert isinstance(first_name_field.type, NonNull) assert first_name_field.type.of_type == String assert first_name_field.description == "First name" assert first_name_field.deprecation_reason is None - last_name_field = ReporterType._meta.fields['last_name'] + last_name_field = ReporterType._meta.fields["last_name"] assert last_name_field.type == String assert last_name_field.description == "Overridden" assert last_name_field.deprecation_reason is None - email_field = ReporterType._meta.fields['email'] + email_field = ReporterType._meta.fields["email"] assert email_field.type == String assert email_field.description == "Email" assert email_field.deprecation_reason == "Overridden" - email_field_v2 = ReporterType._meta.fields['email_v2'] + email_field_v2 = ReporterType._meta.fields["email_v2"] assert email_field_v2.type == Int assert email_field_v2.description == "Email" assert email_field_v2.deprecation_reason is None - hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop_field = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop_field.type == String assert hybrid_prop_field.description == "Overridden" assert hybrid_prop_field.deprecation_reason is None - hybrid_prop_with_doc_field = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc_field = ReporterType._meta.fields["hybrid_prop_with_doc"] assert hybrid_prop_with_doc_field.type == String assert hybrid_prop_with_doc_field.description == "Overridden" assert hybrid_prop_with_doc_field.deprecation_reason is None - column_prop_field_v2 = ReporterType._meta.fields['column_prop'] + column_prop_field_v2 = ReporterType._meta.fields["column_prop"] assert column_prop_field_v2.type == String assert column_prop_field_v2.description is None assert column_prop_field_v2.deprecation_reason is None - composite_prop_field = ReporterType._meta.fields['composite_prop'] + composite_prop_field = ReporterType._meta.fields["composite_prop"] assert composite_prop_field.type == String assert composite_prop_field.description is None assert composite_prop_field.deprecation_reason is None - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert favorite_article_field.type().type == ArticleType - assert favorite_article_field.type().description == 'Overridden' + assert favorite_article_field.type().description == "Overridden" - articles_field = ReporterType._meta.fields['articles'] + articles_field = ReporterType._meta.fields["articles"] assert isinstance(articles_field, Dynamic) assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) assert articles_field.type().deprecation_reason == "Overridden" - pets_field = ReporterType._meta.fields['pets'] + pets_field = ReporterType._meta.fields["pets"] assert isinstance(pets_field, Dynamic) assert isinstance(pets_field.type().type, List) assert pets_field.type().type.of_type == PetType - assert pets_field.type().description == 'Overridden' + assert pets_field.type().description == "Overridden" def test_invalid_model_attr(): err_msg = ( - "Cannot map ORMField to a model attribute.\n" - "Field: 'ReporterType.first_name'" + "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - first_name = ORMField(model_attr='does_not_exist') + first_name = ORMField(model_attr="does_not_exist") def test_only_fields(): @@ -325,29 +346,32 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - "first_name", - "last_name", - "column_prop", - "email", - "favorite_pet_kind", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + "first_name", + "last_name", + "column_prop", + "email", + "favorite_pet_kind", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + "pets", + "articles", + "favorite_article", + ] + ) def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -372,14 +396,14 @@ def test_resolvers(session): class ReporterMixin(object): def resolve_id(root, _info): - return 'ID' + return "ID" class ReporterType(ReporterMixin, SQLAlchemyObjectType): class Meta: model = Reporter email = ORMField() - email_v2 = ORMField(model_attr='email') + email_v2 = ORMField(model_attr="email") favorite_pet_kind = Field(String) favorite_pet_kind_v2 = Field(String) @@ -387,10 +411,10 @@ def resolve_last_name(root, _info): return root.last_name.upper() def resolve_email_v2(root, _info): - return root.email + '_V2' + return root.email + "_V2" def resolve_favorite_pet_kind_v2(root, _info): - return str(root.favorite_pet_kind) + '_V2' + return str(root.favorite_pet_kind) + "_V2" class Query(ObjectType): reporter = Field(ReporterType) @@ -398,12 +422,18 @@ class Query(ObjectType): def resolve_reporter(self, _info): return session.query(Reporter).first() - reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat') + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) session.add(reporter) session.commit() schema = Schema(query=Query) - result = schema.execute(""" + result = schema.execute( + """ query { reporter { id @@ -415,27 +445,29 @@ def resolve_reporter(self, _info): favoritePetKindV2 } } - """) + """ + ) assert not result.errors # Custom resolver on a base class - assert result.data['reporter']['id'] == 'ID' + assert result.data["reporter"]["id"] == "ID" # Default field + default resolver - assert result.data['reporter']['firstName'] == 'first_name' + assert result.data["reporter"]["firstName"] == "first_name" # Default field + custom resolver - assert result.data['reporter']['lastName'] == 'LAST_NAME' + assert result.data["reporter"]["lastName"] == "LAST_NAME" # ORMField + default resolver - assert result.data['reporter']['email'] == 'email' + assert result.data["reporter"]["email"] == "email" # ORMField + custom resolver - assert result.data['reporter']['emailV2'] == 'email_V2' + assert result.data["reporter"]["emailV2"] == "email_V2" # Field + default resolver - assert result.data['reporter']['favoritePetKind'] == 'cat' + assert result.data["reporter"]["favoritePetKind"] == "cat" # Field + custom resolver - assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2' + assert result.data["reporter"]["favoritePetKindV2"] == "cat_V2" # Test Custom SQLAlchemyObjectType Implementation + def test_custom_objecttype_registered(): class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): class Meta: @@ -463,9 +495,9 @@ class Meta: def __init_subclass_with_meta__(cls, custom_option=None, **options): _meta = CustomOptions(cls) _meta.custom_option = custom_option - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super( + SQLAlchemyObjectTypeWithCustomOptions, cls + ).__init_subclass_with_meta__(_meta=_meta, **options) class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): class Meta: @@ -479,6 +511,7 @@ class Meta: # Tests for connection_field_factory + class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): pass @@ -494,7 +527,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), UnsortedSQLAlchemyConnectionField + ) def test_custom_connection_field_factory(): @@ -514,7 +549,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_registerConnectionFieldFactory(): @@ -531,7 +568,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_unregisterConnectionFieldFactory(): @@ -549,7 +588,9 @@ class Meta: model = Article interfaces = (Node,) - assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert not isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_createConnectionField(): @@ -557,7 +598,7 @@ def test_deprecated_createConnectionField(): createConnectionField(None) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_unique_errors_propagate(class_mapper_mock): # Define unique error to detect class UniqueError(Exception): @@ -569,9 +610,11 @@ class UniqueError(Exception): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleOne(SQLAlchemyObjectType): class Meta(object): model = Article + except UniqueError as e: error = e @@ -580,7 +623,7 @@ class Meta(object): assert isinstance(error, UniqueError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_argument_errors_propagate(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError @@ -588,9 +631,11 @@ def test_argument_errors_propagate(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleTwo(SQLAlchemyObjectType): class Meta(object): model = Article + except sqlalchemy.exc.ArgumentError as e: error = e @@ -599,7 +644,7 @@ class Meta(object): assert isinstance(error, sqlalchemy.exc.ArgumentError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_unmapped_errors_reformat(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) @@ -607,9 +652,11 @@ def test_unmapped_errors_reformat(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleThree(SQLAlchemyObjectType): class Meta(object): model = Article + except ValueError as e: error = e diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index de359e05..75328280 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -3,8 +3,14 @@ from graphene import Enum, List, ObjectType, Schema, String -from ..utils import (DummyImport, get_session, sort_argument_for_model, - sort_enum_for_model, to_enum_value_name, to_type_name) +from ..utils import ( + DummyImport, + get_session, + sort_argument_for_model, + sort_enum_for_model, + to_enum_value_name, + to_type_name, +) from .models import Base, Editor, Pet @@ -96,9 +102,11 @@ class MultiplePK(Base): with pytest.warns(DeprecationWarning): arg = sort_argument_for_model(MultiplePK) - assert set(arg.default_value) == set( - (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") - ) + assert set(arg.default_value) == { + MultiplePK.foo.name + "_asc", + MultiplePK.bar.name + "_asc", + } + def test_dummy_import(): dummy_module = DummyImport() diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e6c3d14c..fe48e9eb 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -2,8 +2,7 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty) +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound from graphene import Field @@ -12,12 +11,17 @@ from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -from .converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_hybrid_method, - convert_sqlalchemy_relationship) -from .enums import (enum_for_field, sort_argument_for_object_type, - sort_enum_for_object_type) +from .converter import ( + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, +) +from .enums import ( + enum_for_field, + sort_argument_for_object_type, + sort_enum_for_object_type, +) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import get_query, is_mapped_class, is_mapped_instance @@ -25,15 +29,15 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - _creation_counter=None, - **field_kwargs + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + _creation_counter=None, + **field_kwargs ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -76,20 +80,28 @@ class Meta: super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { - 'model_attr': model_attr, - 'type_': type_, - 'required': required, - 'description': description, - 'deprecation_reason': deprecation_reason, - 'batching': batching, + "model_attr": model_attr, + "type_": type_, + "required": required, + "description": description, + "deprecation_reason": deprecation_reason, + "batching": batching, + } + common_kwargs = { + kwarg: value for kwarg, value in common_kwargs.items() if value is not None } - common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs self.kwargs.update(common_kwargs) def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -112,15 +124,20 @@ def construct_fields( all_model_attrs = OrderedDict( inspected_model.column_attrs.items() + inspected_model.composites.items() - + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() - if isinstance(item, hybrid_property)] + + [ + (name, item) + for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property) + ] + inspected_model.relationships.items() ) # Filter out excluded fields auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): - if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): + if (only_fields and attr_name not in only_fields) or ( + attr_name in exclude_fields + ): continue auto_orm_field_names.append(attr_name) @@ -135,13 +152,15 @@ def construct_fields( # Set the model_attr if not set for orm_field_name, orm_field in custom_orm_fields_items: - attr_name = orm_field.kwargs.get('model_attr', orm_field_name) + attr_name = orm_field.kwargs.get("model_attr", orm_field_name) if attr_name not in all_model_attrs: - raise ValueError(( - "Cannot map ORMField to a model attribute.\n" - "Field: '{}.{}'" - ).format(obj_type.__name__, orm_field_name,)) - orm_field.kwargs['model_attr'] = attr_name + raise ValueError( + ("Cannot map ORMField to a model attribute.\n" "Field: '{}.{}'").format( + obj_type.__name__, + orm_field_name, + ) + ) + orm_field.kwargs["model_attr"] = attr_name # Merge automatic fields with custom ORM fields orm_fields = OrderedDict(custom_orm_fields_items) @@ -153,27 +172,38 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): - attr_name = orm_field.kwargs.pop('model_attr') + attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] - resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( + obj_type, attr_name + ) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_column( + attr, registry, resolver, **orm_field.kwargs + ) elif isinstance(attr, RelationshipProperty): - batching_ = orm_field.kwargs.pop('batching', batching) + batching_ = orm_field.kwargs.pop("batching", batching) field = convert_sqlalchemy_relationship( - attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) + attr, + obj_type, + connection_field_factory, + batching_, + orm_field_name, + **orm_field.kwargs + ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " - "Field: {}.{}".format(obj_type.__name__, orm_field_name)) + "Field: {}.{}".format(obj_type.__name__, orm_field_name) + ) field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: - raise Exception('Property type is not supported') # Should never happen + raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field @@ -191,26 +221,27 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options ): # Make sure model is a valid SQLAlchemy model if not is_mapped_class(model): raise ValueError( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + "You need to pass a valid SQLAlchemy Model in " + '{}.Meta, received "{}".'.format(cls.__name__, model) ) if not registry: @@ -222,7 +253,9 @@ def __init_subclass_with_meta__( ).format(cls.__name__, registry) if only_fields and exclude_fields: - raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") + raise ValueError( + "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." + ) sqla_fields = yank_fields_from_attrs( construct_fields( @@ -240,7 +273,7 @@ def __init_subclass_with_meta__( if use_connection is None and interfaces: use_connection = any( - (issubclass(interface, Node) for interface in interfaces) + issubclass(interface, Node) for interface in interfaces ) if use_connection and not connection: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 27117c0c..54bb8402 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -153,12 +153,16 @@ def sort_argument_for_model(cls, has_default=True): def is_sqlalchemy_version_less_than(version_string): # pragma: no cover """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) def is_graphene_version_less_than(version_string): # pragma: no cover """Check the installed graphene version""" - return pkg_resources.get_distribution('graphene').parsed_version < pkg_resources.parse_version(version_string) + return pkg_resources.get_distribution( + "graphene" + ).parsed_version < pkg_resources.parse_version(version_string) class singledispatchbymatchfunction: @@ -182,7 +186,6 @@ def __call__(self, *args, **kwargs): return self.default(*args, **kwargs) def register(self, matcher_function: Callable[[Any], bool]): - def grab_function_from_outside(f): self.registry[matcher_function] = f return self @@ -192,7 +195,7 @@ def grab_function_from_outside(f): def value_equals(value): """A simple function that makes the equality based matcher functions for - SingleDispatchByMatchFunction prettier""" + SingleDispatchByMatchFunction prettier""" return lambda x: x == value @@ -208,8 +211,14 @@ def safe_isinstance_checker(arg): def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: from graphene_sqlalchemy.registry import get_global_registry + try: - return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + return next( + filter( + lambda x: x.__name__ == model_name, + list(get_global_registry()._registry.keys()), + ) + ) except StopIteration: pass diff --git a/setup.cfg b/setup.cfg index f36334d8..e479585c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,10 +2,12 @@ test=pytest [flake8] -exclude = setup.py,docs/*,examples/*,tests +ignore = E203,W503 +exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs,setup.py,docs/*,examples/*,tests max-line-length = 120 [isort] +profile = black no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy