diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index f598e34..953f9a0 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -1,9 +1,17 @@ import json import inspect + from django.contrib.admindocs.views import simplify_regex from django.utils.encoding import force_str + +from rest_framework.viewsets import ModelViewSet from rest_framework.serializers import BaseSerializer +VIEWSET_METHODS = { + 'List': ['get', 'post'], + 'Instance': ['get', 'put', 'patch', 'delete'], +} + class ApiEndpoint(object): @@ -31,8 +39,14 @@ def __get_path__(self, parent_regex): return "/{0}{1}".format(self.name_parent, simplify_regex(self.pattern.regex.pattern)) return simplify_regex(self.pattern.regex.pattern) - def __get_allowed_methods__(self): + def is_method_allowed(self, callback_cls, method_name): + has_attr = hasattr(callback_cls, method_name) + viewset_method = (issubclass(callback_cls, ModelViewSet) and + method_name in VIEWSET_METHODS.get(self.callback.suffix, [])) + + return has_attr or viewset_method + def __get_allowed_methods__(self): viewset_methods = [] if self.drf_router: for prefix, viewset, basename in self.drf_router.registry: @@ -57,14 +71,18 @@ def __get_allowed_methods__(self): ) if self.pattern.regex.pattern == regex: funcs, viewset_methods = zip( - *[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping] + *[(mapping[m], m.upper()) + for m in self.callback.cls.http_method_names + if m in mapping] ) viewset_methods = list(viewset_methods) if len(set(funcs)) == 1: self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0])) - view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)] - return viewset_methods + view_methods + view_methods = [force_str(m).upper() + for m in self.callback.cls.http_method_names + if self.is_method_allowed(self.callback.cls, m)] + return sorted(viewset_methods + view_methods) def __get_docstring__(self): return inspect.getdoc(self.callback) diff --git a/tests/tests.py b/tests/tests.py index 998faee..f94736c 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -21,7 +21,7 @@ def test_settings_module(self): def test_index_view_with_endpoints(self): """ - Should load the drf focs view with all the endpoints. + Should load the drf docs view with all the endpoints. NOTE: Views that do **not** inherit from DRF's "APIView" are not included. """ response = self.client.get(reverse('drfdocs')) @@ -31,7 +31,7 @@ def test_index_view_with_endpoints(self): # Test the login view self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") - self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(set(response.context["endpoints"][0].allowed_methods), set(['OPTIONS', 'POST'])) self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/") self.assertEqual(response.context["endpoints"][0].docstring, "A view that allows users to login providing their username and password.") self.assertEqual(len(response.context["endpoints"][0].fields), 2) @@ -39,7 +39,7 @@ def test_index_view_with_endpoints(self): self.assertTrue(response.context["endpoints"][0].fields[0]["required"]) self.assertEqual(response.context["endpoints"][1].name_parent, "accounts") - self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(set(response.context["endpoints"][1].allowed_methods), set(['POST', 'OPTIONS'])) self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/") self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class") self.assertEqual(len(response.context["endpoints"][1].fields), 2) @@ -77,7 +77,7 @@ def test_model_viewset(self): self.assertEqual(response.context['endpoints'][6].fields[2]['to_many_relation'], True) self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/') self.assertEqual(response.context["endpoints"][12].path, '/organisation-model-viewsets//') - self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][12].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][13].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(set(response.context["endpoints"][11].allowed_methods), set(['GET', 'POST', 'OPTIONS'])) + self.assertEqual(set(response.context["endpoints"][12].allowed_methods), set(['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])) + self.assertEqual(set(response.context["endpoints"][13].allowed_methods), set(['POST', 'OPTIONS'])) self.assertEqual(response.context["endpoints"][13].docstring, 'This is a test.')