Skip to content

Commit aff56b8

Browse files
tcleonardThomas Leonard
and
Thomas Leonard
authored
Validate in and range filter inputs (#1092)
Co-authored-by: Thomas Leonard <[email protected]>
1 parent 1281c13 commit aff56b8

File tree

6 files changed

+211
-34
lines changed

6 files changed

+211
-34
lines changed

Diff for: graphene_django/filter/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
else:
1111
from .fields import DjangoFilterConnectionField
12-
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
12+
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
1313

1414
__all__ = [
1515
"DjangoFilterConnectionField",

Diff for: graphene_django/filter/filters.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from django.core.exceptions import ValidationError
2+
from django.forms import Field
3+
4+
from django_filters import Filter, MultipleChoiceFilter
5+
6+
from graphql_relay.node.node import from_global_id
7+
8+
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
9+
10+
11+
class GlobalIDFilter(Filter):
12+
"""
13+
Filter for Relay global ID.
14+
"""
15+
16+
field_class = GlobalIDFormField
17+
18+
def filter(self, qs, value):
19+
""" Convert the filter value to a primary key before filtering """
20+
_id = None
21+
if value is not None:
22+
_, _id = from_global_id(value)
23+
return super(GlobalIDFilter, self).filter(qs, _id)
24+
25+
26+
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
27+
field_class = GlobalIDMultipleChoiceField
28+
29+
def filter(self, qs, value):
30+
gids = [from_global_id(v)[1] for v in value]
31+
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
32+
33+
34+
class InFilter(Filter):
35+
"""
36+
Filter for a list of value using the `__in` Django filter.
37+
"""
38+
39+
def filter(self, qs, value):
40+
"""
41+
Override the default filter class to check first weather the list is
42+
empty or not.
43+
This needs to be done as in this case we expect to get an empty output
44+
(if not an exclude filter) but django_filter consider an empty list
45+
to be an empty input value (see `EMPTY_VALUES`) meaning that
46+
the filter does not need to be applied (hence returning the original
47+
queryset).
48+
"""
49+
if value is not None and len(value) == 0:
50+
if self.exclude:
51+
return qs
52+
else:
53+
return qs.none()
54+
else:
55+
return super(InFilter, self).filter(qs, value)
56+
57+
58+
def validate_range(value):
59+
"""
60+
Validator for range filter input: the list of value must be of length 2.
61+
Note that validators are only run if the value is not empty.
62+
"""
63+
if len(value) != 2:
64+
raise ValidationError(
65+
"Invalid range specified: it needs to contain 2 values.", code="invalid"
66+
)
67+
68+
69+
class RangeField(Field):
70+
default_validators = [validate_range]
71+
empty_values = [None]
72+
73+
74+
class RangeFilter(Filter):
75+
field_class = RangeField

Diff for: graphene_django/filter/filterset.py

+2-23
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,11 @@
11
import itertools
22

33
from django.db import models
4-
from django_filters import Filter, MultipleChoiceFilter, VERSION
4+
from django_filters import VERSION
55
from django_filters.filterset import BaseFilterSet, FilterSet
66
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
77

8-
from graphql_relay.node.node import from_global_id
9-
10-
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
11-
12-
13-
class GlobalIDFilter(Filter):
14-
field_class = GlobalIDFormField
15-
16-
def filter(self, qs, value):
17-
""" Convert the filter value to a primary key before filtering """
18-
_id = None
19-
if value is not None:
20-
_, _id = from_global_id(value)
21-
return super(GlobalIDFilter, self).filter(qs, _id)
22-
23-
24-
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
25-
field_class = GlobalIDMultipleChoiceField
26-
27-
def filter(self, qs, value):
28-
gids = [from_global_id(v)[1] for v in value]
29-
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
8+
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
309

3110

3211
GRAPHENE_FILTER_SET_OVERRIDES = {

Diff for: graphene_django/filter/tests/test_in_filter.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,19 @@ def test_int_in_filter():
157157
]
158158

159159

160-
def test_int_range_filter():
160+
def test_in_filter_with_empty_list():
161161
"""
162-
Test in filter on an integer field.
162+
Check that using a in filter with an empty list provided as input returns no objects.
163163
"""
164164
Pet.objects.create(name="Brutus", age=12)
165165
Pet.objects.create(name="Mimi", age=8)
166-
Pet.objects.create(name="Jojo, the rabbit", age=3)
167166
Pet.objects.create(name="Picotin", age=5)
168167

169168
schema = Schema(query=Query)
170169

171170
query = """
172171
query {
173-
pets (age_Range: [4, 9]) {
172+
pets (name_In: []) {
174173
edges {
175174
node {
176175
name
@@ -181,7 +180,4 @@ def test_int_range_filter():
181180
"""
182181
result = schema.execute(query)
183182
assert not result.errors
184-
assert result.data["pets"]["edges"] == [
185-
{"node": {"name": "Mimi"}},
186-
{"node": {"name": "Picotin"}},
187-
]
183+
assert len(result.data["pets"]["edges"]) == 0

Diff for: graphene_django/filter/tests/test_range_filter.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import ast
2+
import json
3+
import pytest
4+
5+
from django_filters import FilterSet
6+
from django_filters import rest_framework as filters
7+
from graphene import ObjectType, Schema
8+
from graphene.relay import Node
9+
from graphene_django import DjangoObjectType
10+
from graphene_django.tests.models import Pet
11+
from graphene_django.utils import DJANGO_FILTER_INSTALLED
12+
13+
pytestmark = []
14+
15+
if DJANGO_FILTER_INSTALLED:
16+
from graphene_django.filter import DjangoFilterConnectionField
17+
else:
18+
pytestmark.append(
19+
pytest.mark.skipif(
20+
True, reason="django_filters not installed or not compatible"
21+
)
22+
)
23+
24+
25+
class PetNode(DjangoObjectType):
26+
class Meta:
27+
model = Pet
28+
interfaces = (Node,)
29+
filter_fields = {
30+
"name": ["exact", "in"],
31+
"age": ["exact", "in", "range"],
32+
}
33+
34+
35+
class Query(ObjectType):
36+
pets = DjangoFilterConnectionField(PetNode)
37+
38+
39+
def test_int_range_filter():
40+
"""
41+
Test range filter on an integer field.
42+
"""
43+
Pet.objects.create(name="Brutus", age=12)
44+
Pet.objects.create(name="Mimi", age=8)
45+
Pet.objects.create(name="Jojo, the rabbit", age=3)
46+
Pet.objects.create(name="Picotin", age=5)
47+
48+
schema = Schema(query=Query)
49+
50+
query = """
51+
query {
52+
pets (age_Range: [4, 9]) {
53+
edges {
54+
node {
55+
name
56+
}
57+
}
58+
}
59+
}
60+
"""
61+
result = schema.execute(query)
62+
assert not result.errors
63+
assert result.data["pets"]["edges"] == [
64+
{"node": {"name": "Mimi"}},
65+
{"node": {"name": "Picotin"}},
66+
]
67+
68+
69+
def test_range_filter_with_invalid_input():
70+
"""
71+
Test range filter used with invalid inputs raise an error.
72+
"""
73+
Pet.objects.create(name="Brutus", age=12)
74+
Pet.objects.create(name="Mimi", age=8)
75+
Pet.objects.create(name="Jojo, the rabbit", age=3)
76+
Pet.objects.create(name="Picotin", age=5)
77+
78+
schema = Schema(query=Query)
79+
80+
query = """
81+
query ($rangeValue: [Int]) {
82+
pets (age_Range: $rangeValue) {
83+
edges {
84+
node {
85+
name
86+
}
87+
}
88+
}
89+
}
90+
"""
91+
expected_error = json.dumps(
92+
{
93+
"age__range": [
94+
{
95+
"message": "Invalid range specified: it needs to contain 2 values.",
96+
"code": "invalid",
97+
}
98+
]
99+
}
100+
)
101+
102+
# Empty list
103+
result = schema.execute(query, variables={"rangeValue": []})
104+
assert len(result.errors) == 1
105+
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
106+
107+
# Only one item in the list
108+
result = schema.execute(query, variables={"rangeValue": [1]})
109+
assert len(result.errors) == 1
110+
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
111+
112+
# More than 2 items in the list
113+
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
114+
assert len(result.errors) == 1
115+
assert ast.literal_eval(result.errors[0].message)[0] == expected_error

Diff for: graphene_django/filter/utils.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from django_filters.filters import Filter, BaseCSVFilter
77

88
from .filterset import custom_filterset_factory, setup_filterset
9+
from .filters import InFilter, RangeFilter
910

1011

1112
def get_filtering_args_from_filterset(filterset_class, type):
@@ -80,9 +81,20 @@ def replace_csv_filters(filterset_class):
8081
"""
8182
for name, filter_field in six.iteritems(filterset_class.base_filters):
8283
filter_type = filter_field.lookup_expr
83-
if filter_type in ["in", "range"]:
84+
if filter_type == "in":
85+
assert isinstance(filter_field, BaseCSVFilter)
86+
filterset_class.base_filters[name] = InFilter(
87+
field_name=filter_field.field_name,
88+
lookup_expr=filter_field.lookup_expr,
89+
label=filter_field.label,
90+
method=filter_field.method,
91+
exclude=filter_field.exclude,
92+
**filter_field.extra
93+
)
94+
95+
if filter_type == "range":
8496
assert isinstance(filter_field, BaseCSVFilter)
85-
filterset_class.base_filters[name] = Filter(
97+
filterset_class.base_filters[name] = RangeFilter(
8698
field_name=filter_field.field_name,
8799
lookup_expr=filter_field.lookup_expr,
88100
label=filter_field.label,

0 commit comments

Comments
 (0)