@@ -440,13 +440,12 @@ def _get_graphql_base_name(self):
440
440
base_name = self .model ._meta .verbose_name .lower ().replace (' ' , '_' )
441
441
return getattr (self , 'graphql_base_name' , base_name )
442
442
443
- def _build_query (self , name , ** filters ):
443
+ def _build_query_with_filter (self , name , filter_string ):
444
+ """
445
+ Called by either _build_query or _build_filtered_query - construct the actual
446
+ query given a name and filter string
447
+ """
444
448
type_class = get_graphql_type_for_model (self .model )
445
- if filters :
446
- filter_string = ', ' .join (f'{ k } :{ v } ' for k , v in filters .items ())
447
- filter_string = f'({ filter_string } )'
448
- else :
449
- filter_string = ''
450
449
451
450
# Compile list of fields to include
452
451
fields_string = ''
@@ -492,6 +491,30 @@ def _build_query(self, name, **filters):
492
491
493
492
return query
494
493
494
+ def _build_filtered_query (self , name , ** filters ):
495
+ """
496
+ Create a filtered query: i.e. ip_address_list(filters: {address: "1.1.1.1/24"}){.
497
+ """
498
+ if filters :
499
+ filter_string = ', ' .join (f'{ k } : "{ v } "' for k , v in filters .items ())
500
+ filter_string = f'(filters: {{{ filter_string } }})'
501
+ else :
502
+ filter_string = ''
503
+
504
+ return self ._build_query_with_filter (name , filter_string )
505
+
506
+ def _build_query (self , name , ** filters ):
507
+ """
508
+ Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
509
+ """
510
+ if filters :
511
+ filter_string = ', ' .join (f'{ k } :{ v } ' for k , v in filters .items ())
512
+ filter_string = f'({ filter_string } )'
513
+ else :
514
+ filter_string = ''
515
+
516
+ return self ._build_query_with_filter (name , filter_string )
517
+
495
518
@override_settings (LOGIN_REQUIRED = True )
496
519
@override_settings (EXEMPT_VIEW_PERMISSIONS = ['*' , 'auth.user' ])
497
520
def test_graphql_get_object (self ):
@@ -550,6 +573,31 @@ def test_graphql_list_objects(self):
550
573
self .assertNotIn ('errors' , data )
551
574
self .assertGreater (len (data ['data' ][field_name ]), 0 )
552
575
576
+ @override_settings (LOGIN_REQUIRED = True )
577
+ @override_settings (EXEMPT_VIEW_PERMISSIONS = ['*' , 'auth.user' ])
578
+ def test_graphql_filter_objects (self ):
579
+ if not hasattr (self , 'graphql_filter' ):
580
+ return
581
+
582
+ url = reverse ('graphql' )
583
+ field_name = f'{ self ._get_graphql_base_name ()} _list'
584
+ query = self ._build_filtered_query (field_name , ** self .graphql_filter )
585
+
586
+ # Add object-level permission
587
+ obj_perm = ObjectPermission (
588
+ name = 'Test permission' ,
589
+ actions = ['view' ]
590
+ )
591
+ obj_perm .save ()
592
+ obj_perm .users .add (self .user )
593
+ obj_perm .object_types .add (ObjectType .objects .get_for_model (self .model ))
594
+
595
+ response = self .client .post (url , data = {'query' : query }, format = "json" , ** self .header )
596
+ self .assertHttpStatus (response , status .HTTP_200_OK )
597
+ data = json .loads (response .content )
598
+ self .assertNotIn ('errors' , data )
599
+ self .assertGreater (len (data ['data' ][field_name ]), 0 )
600
+
553
601
class APIViewTestCase (
554
602
GetObjectViewTestCase ,
555
603
ListObjectsViewTestCase ,
0 commit comments