Skip to content

Merge support views #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions rest_framework_docs/api_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import json
import inspect

from django.contrib.admindocs.views import simplify_regex
from django.utils.encoding import force_str

from rest_framework.viewsets import ModelViewSet
from rest_framework.serializers import BaseSerializer

VIEWSET_METHODS = {
'List': ['get', 'post'],
'Instance': ['get', 'put', 'patch', 'delete'],
}


class ApiEndpoint(object):

Expand Down Expand Up @@ -31,8 +39,14 @@ def __get_path__(self, parent_regex):
return "/{0}{1}".format(self.name_parent, simplify_regex(self.pattern.regex.pattern))
return simplify_regex(self.pattern.regex.pattern)

def __get_allowed_methods__(self):
def is_method_allowed(self, callback_cls, method_name):
has_attr = hasattr(callback_cls, method_name)
viewset_method = (issubclass(callback_cls, ModelViewSet) and
method_name in VIEWSET_METHODS.get(self.callback.suffix, []))

return has_attr or viewset_method

def __get_allowed_methods__(self):
viewset_methods = []
if self.drf_router:
for prefix, viewset, basename in self.drf_router.registry:
Expand All @@ -57,14 +71,18 @@ def __get_allowed_methods__(self):
)
if self.pattern.regex.pattern == regex:
funcs, viewset_methods = zip(
*[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping]
*[(mapping[m], m.upper())
for m in self.callback.cls.http_method_names
if m in mapping]
)
viewset_methods = list(viewset_methods)
if len(set(funcs)) == 1:
self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0]))

view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
return viewset_methods + view_methods
view_methods = [force_str(m).upper()
for m in self.callback.cls.http_method_names
if self.is_method_allowed(self.callback.cls, m)]
return sorted(viewset_methods + view_methods)

def __get_docstring__(self):
return inspect.getdoc(self.callback)
Expand Down
12 changes: 6 additions & 6 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_settings_module(self):

def test_index_view_with_endpoints(self):
"""
Should load the drf focs view with all the endpoints.
Should load the drf docs view with all the endpoints.
NOTE: Views that do **not** inherit from DRF's "APIView" are not included.
"""
response = self.client.get(reverse('drfdocs'))
Expand All @@ -31,15 +31,15 @@ def test_index_view_with_endpoints(self):

# Test the login view
self.assertEqual(response.context["endpoints"][0].name_parent, "accounts")
self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS'])
self.assertEqual(set(response.context["endpoints"][0].allowed_methods), set(['OPTIONS', 'POST']))
self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/")
self.assertEqual(response.context["endpoints"][0].docstring, "A view that allows users to login providing their username and password.")
self.assertEqual(len(response.context["endpoints"][0].fields), 2)
self.assertEqual(response.context["endpoints"][0].fields[0]["type"], "CharField")
self.assertTrue(response.context["endpoints"][0].fields[0]["required"])

self.assertEqual(response.context["endpoints"][1].name_parent, "accounts")
self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS'])
self.assertEqual(set(response.context["endpoints"][1].allowed_methods), set(['POST', 'OPTIONS']))
self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/")
self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class")
self.assertEqual(len(response.context["endpoints"][1].fields), 2)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_model_viewset(self):
self.assertEqual(response.context['endpoints'][6].fields[2]['to_many_relation'], True)
self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/')
self.assertEqual(response.context["endpoints"][12].path, '/organisation-model-viewsets/<pk>/')
self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'POST', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][12].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][13].allowed_methods, ['POST', 'OPTIONS'])
self.assertEqual(set(response.context["endpoints"][11].allowed_methods), set(['GET', 'POST', 'OPTIONS']))
self.assertEqual(set(response.context["endpoints"][12].allowed_methods), set(['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']))
self.assertEqual(set(response.context["endpoints"][13].allowed_methods), set(['POST', 'OPTIONS']))
self.assertEqual(response.context["endpoints"][13].docstring, 'This is a test.')