Skip to content

Commit b291aa4

Browse files
authored
16078 make GraphQL NumberFilter optional (#16115)
* 16078 make GraphQL NumberFilter optional * 16078 add tests for graphql filtering * 16078 add tests for graphql filtering * 16078 add tests for graphql filtering
1 parent e6ccea0 commit b291aa4

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

netbox/ipam/tests/test_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,9 @@ class IPAddressTest(APIViewTestCases.APIViewTestCase):
648648
bulk_update_data = {
649649
'description': 'New description',
650650
}
651+
graphql_filter = {
652+
'address': '192.168.0.1/24',
653+
}
651654

652655
@classmethod
653656
def setUpTestData(cls):

netbox/netbox/graphql/filter_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def map_strawberry_type(field):
8787
pass
8888
elif issubclass(type(field), django_filters.NumberFilter):
8989
should_create_function = True
90-
attr_type = int
90+
attr_type = int | None
9191
elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter):
9292
should_create_function = True
9393
attr_type = List[str] | None

netbox/utilities/testing/api.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,12 @@ def _get_graphql_base_name(self):
440440
base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
441441
return getattr(self, 'graphql_base_name', base_name)
442442

443-
def _build_query(self, name, **filters):
443+
def _build_query_with_filter(self, name, filter_string):
444+
"""
445+
Called by either _build_query or _build_filtered_query - construct the actual
446+
query given a name and filter string
447+
"""
444448
type_class = get_graphql_type_for_model(self.model)
445-
if filters:
446-
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
447-
filter_string = f'({filter_string})'
448-
else:
449-
filter_string = ''
450449

451450
# Compile list of fields to include
452451
fields_string = ''
@@ -492,6 +491,30 @@ def _build_query(self, name, **filters):
492491

493492
return query
494493

494+
def _build_filtered_query(self, name, **filters):
495+
"""
496+
Create a filtered query: i.e. ip_address_list(filters: {address: "1.1.1.1/24"}){.
497+
"""
498+
if filters:
499+
filter_string = ', '.join(f'{k}: "{v}"' for k, v in filters.items())
500+
filter_string = f'(filters: {{{filter_string}}})'
501+
else:
502+
filter_string = ''
503+
504+
return self._build_query_with_filter(name, filter_string)
505+
506+
def _build_query(self, name, **filters):
507+
"""
508+
Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
509+
"""
510+
if filters:
511+
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
512+
filter_string = f'({filter_string})'
513+
else:
514+
filter_string = ''
515+
516+
return self._build_query_with_filter(name, filter_string)
517+
495518
@override_settings(LOGIN_REQUIRED=True)
496519
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
497520
def test_graphql_get_object(self):
@@ -550,6 +573,31 @@ def test_graphql_list_objects(self):
550573
self.assertNotIn('errors', data)
551574
self.assertGreater(len(data['data'][field_name]), 0)
552575

576+
@override_settings(LOGIN_REQUIRED=True)
577+
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
578+
def test_graphql_filter_objects(self):
579+
if not hasattr(self, 'graphql_filter'):
580+
return
581+
582+
url = reverse('graphql')
583+
field_name = f'{self._get_graphql_base_name()}_list'
584+
query = self._build_filtered_query(field_name, **self.graphql_filter)
585+
586+
# Add object-level permission
587+
obj_perm = ObjectPermission(
588+
name='Test permission',
589+
actions=['view']
590+
)
591+
obj_perm.save()
592+
obj_perm.users.add(self.user)
593+
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
594+
595+
response = self.client.post(url, data={'query': query}, format="json", **self.header)
596+
self.assertHttpStatus(response, status.HTTP_200_OK)
597+
data = json.loads(response.content)
598+
self.assertNotIn('errors', data)
599+
self.assertGreater(len(data['data'][field_name]), 0)
600+
553601
class APIViewTestCase(
554602
GetObjectViewTestCase,
555603
ListObjectsViewTestCase,

0 commit comments

Comments
 (0)