Skip to content

add viewset support #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rest_framework_docs/api_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

class ApiDocumentation(object):

def __init__(self):
def __init__(self, drf_router=None):
self.endpoints = []
self.drf_router = drf_router
try:
root_urlconf = import_string(settings.ROOT_URLCONF)
except ImportError:
Expand All @@ -26,7 +27,7 @@ def get_all_view_names(self, urlpatterns, parent_pattern=None):
parent_pattern = None if pattern._regex == "^" else pattern
self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=parent_pattern)
elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern):
api_endpoint = ApiEndpoint(pattern, parent_pattern)
api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router)
self.endpoints.append(api_endpoint)

def _is_drf_view(self, pattern):
Expand Down
37 changes: 35 additions & 2 deletions rest_framework_docs/api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

class ApiEndpoint(object):

def __init__(self, pattern, parent_pattern=None):
def __init__(self, pattern, parent_pattern=None, drf_router=None):
self.drf_router = drf_router
self.pattern = pattern
self.callback = pattern.callback
# self.name = pattern.name
Expand All @@ -26,7 +27,39 @@ def __get_path__(self, parent_pattern):
return simplify_regex(self.pattern.regex.pattern)

def __get_allowed_methods__(self):
return [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]

viewset_methods = []
if self.drf_router:
for prefix, viewset, basename in self.drf_router.registry:
if self.callback.cls != viewset:
continue

lookup = self.drf_router.get_lookup_regex(viewset)
routes = self.drf_router.get_routes(viewset)

for route in routes:

# Only actions which actually exist on the viewset will be bound
mapping = self.drf_router.get_method_map(viewset, route.mapping)
if not mapping:
continue

# Build the url pattern
regex = route.url.format(
prefix=prefix,
lookup=lookup,
trailing_slash=self.drf_router.trailing_slash
)
if self.pattern.regex.pattern == regex:
funcs, viewset_methods = zip(
*[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping]
)
viewset_methods = list(viewset_methods)
if len(set(funcs)) == 1:
self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0]))

view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
return viewset_methods + view_methods

def __get_docstring__(self):
return inspect.getdoc(self.callback)
Expand Down
3 changes: 2 additions & 1 deletion rest_framework_docs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
class DRFDocsView(TemplateView):

template_name = "rest_framework_docs/home.html"
drf_router = None

def get_context_data(self, **kwargs):
settings = DRFSettings().settings
if settings["HIDE_DOCS"]:
raise Http404("Django Rest Framework Docs are hidden. Check your settings.")

context = super(DRFDocsView, self).get_context_data(**kwargs)
docs = ApiDocumentation()
docs = ApiDocumentation(drf_router=self.drf_router)
endpoints = docs.get_endpoints()

query = self.request.GET.get("search", "")
Expand Down
13 changes: 12 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self):
response = self.client.get(reverse('drfdocs'))

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["endpoints"]), 11)
self.assertEqual(len(response.context["endpoints"]), 14)

# Test the login view
self.assertEqual(response.context["endpoints"][0].name_parent, "accounts")
Expand Down Expand Up @@ -67,3 +67,14 @@ def test_index_view_docs_hidden(self):

self.assertEqual(response.status_code, 404)
self.assertEqual(response.reason_phrase.upper(), "NOT FOUND")

def test_model_viewset(self):
response = self.client.get(reverse('drfdocs'))

self.assertEqual(response.status_code, 200)
self.assertEqual(response.context["endpoints"][10].path, '/organisation-model-viewsets/')
self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/<pk>/')
self.assertEqual(response.context["endpoints"][10].allowed_methods, ['GET', 'POST', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][12].allowed_methods, ['POST', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][12].docstring, 'This is a test.')
8 changes: 7 additions & 1 deletion tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from django.conf.urls import include, url
from django.contrib import admin
from rest_framework.routers import SimpleRouter
from rest_framework_docs.views import DRFDocsView
from tests import views

accounts_urls = [
Expand All @@ -23,13 +25,17 @@
url(r'^(?P<slug>[\w-]+)/errored/$', view=views.OrganisationErroredView.as_view(), name="errored")
]

router = SimpleRouter()
router.register('organisation-model-viewsets', views.TestModelViewSet, base_name='organisation')

urlpatterns = [
url(r'^admin/', include(admin.site.urls)),
url(r'^docs/', include('rest_framework_docs.urls')),
url(r'^docs/', DRFDocsView.as_view(drf_router=router), name='drfdocs'),

# API
url(r'^accounts/', view=include(accounts_urls, namespace='accounts')),
url(r'^organisations/', view=include(organisations_urls, namespace='organisations')),
url(r'^', include(router.urls)),

# Endpoints without parents/namespaces
url(r'^another-login/$', views.LoginView.as_view(), name="login"),
Expand Down
12 changes: 12 additions & 0 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from rest_framework import parsers, renderers, generics, status
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.decorators import detail_route
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
from tests.models import User, Organisation, Membership
from tests import serializers

Expand Down Expand Up @@ -132,3 +134,13 @@ def post(self, request):

def get_serializer_class(self):
return AuthTokenSerializer


class TestModelViewSet(ModelViewSet):
queryset = Organisation.objects.all()
serializer_class = serializers.OrganisationMembersSerializer

@detail_route(methods=['post'])
def test_route(self, request):
"""This is a test."""
return Response()