diff --git a/rest_framework_docs/api_docs.py b/rest_framework_docs/api_docs.py index d22dd4c..1364041 100644 --- a/rest_framework_docs/api_docs.py +++ b/rest_framework_docs/api_docs.py @@ -7,7 +7,6 @@ class ApiDocumentation(object): - def __init__(self, drf_router=None): self.endpoints = [] self.drf_router = drf_router @@ -21,13 +20,16 @@ def __init__(self, drf_router=None): else: self.get_all_view_names(root_urlconf.urlpatterns) - def get_all_view_names(self, urlpatterns, parent_pattern=None): + def get_all_view_names(self, urlpatterns, previous_parent_patterns=None): for pattern in urlpatterns: + parent_patterns = list(previous_parent_patterns or []) if isinstance(pattern, RegexURLResolver): - parent_pattern = None if pattern._regex == "^" else pattern - self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=parent_pattern) - elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern): - api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router) + if not pattern._regex == "^": + parent_patterns.append(pattern) + self.get_all_view_names(urlpatterns=pattern.url_patterns, previous_parent_patterns=parent_patterns) + elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) \ + and not self._is_format_endpoint(pattern): + api_endpoint = ApiEndpoint(pattern, parent_patterns, self.drf_router) self.endpoints.append(api_endpoint) def _is_drf_view(self, pattern): diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index 89a33f8..8ef2080 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -7,14 +7,16 @@ class ApiEndpoint(object): - def __init__(self, pattern, parent_pattern=None, drf_router=None): + def __init__(self, pattern, parent_patterns=None, drf_router=None): self.drf_router = drf_router self.pattern = pattern self.callback = pattern.callback # self.name = pattern.name self.docstring = self.__get_docstring__() - self.name_parent = simplify_regex(parent_pattern.regex.pattern).strip('/') if parent_pattern else None - self.path = self.__get_path__(parent_pattern) + self.name_parent = ''.join([parent_pattern.regex.pattern for parent_pattern in (parent_patterns or [])]) + self.name_parent_suffix = '/' if self.name_parent.endswith('/') else '' + self.name_parent = simplify_regex(self.name_parent).strip('/') + self.path = self.__get_path__(parent_patterns) self.allowed_methods = self.__get_allowed_methods__() # self.view_name = pattern.callback.__name__ self.errors = None @@ -26,9 +28,9 @@ def __init__(self, pattern, parent_pattern=None, drf_router=None): self.permissions = self.__get_permissions_class__() - def __get_path__(self, parent_pattern): - if parent_pattern: - return "/{0}{1}".format(self.name_parent, simplify_regex(self.pattern.regex.pattern)) + def __get_path__(self, parent_patterns): + if parent_patterns: + return simplify_regex("{}{}{}".format(self.name_parent, self.name_parent_suffix, self.pattern.regex.pattern)) return simplify_regex(self.pattern.regex.pattern) def __get_allowed_methods__(self): diff --git a/tests/tests.py b/tests/tests.py index 998faee..4f4903c 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self): response = self.client.get(reverse('drfdocs')) self.assertEqual(response.status_code, 200) - self.assertEqual(len(response.context["endpoints"]), 15) + self.assertEqual(len(response.context["endpoints"]), 17) # Test the login view self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") @@ -49,6 +49,16 @@ def test_index_view_with_endpoints(self): # The view "OrganisationErroredView" (organisations/(?P[\w-]+)/errored/) should contain an error. self.assertEqual(str(response.context["endpoints"][9].errors), "'test_value'") + def test_deep_recurrence(self): + response = self.client.get("%s?search=inherited" % reverse("drfdocs")) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 2) + + endpoints = response.context["endpoints"] + self.assertEqual(endpoints[0].path, '/organisations/inherited/view/') + self.assertEqual(endpoints[1].path, '/organisations/inherited2no-slash/') + def test_index_search_with_endpoints(self): response = self.client.get("%s?search=reset-password" % reverse("drfdocs")) @@ -75,9 +85,9 @@ def test_model_viewset(self): self.assertEqual(response.context["endpoints"][10].path, '/organisations//') self.assertEqual(response.context['endpoints'][6].fields[2]['to_many_relation'], True) - self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/') - self.assertEqual(response.context["endpoints"][12].path, '/organisation-model-viewsets//') - self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][12].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][13].allowed_methods, ['POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][13].docstring, 'This is a test.') + self.assertEqual(response.context["endpoints"][13].path, '/organisation-model-viewsets/') + self.assertEqual(response.context["endpoints"][14].path, '/organisation-model-viewsets//') + self.assertEqual(response.context["endpoints"][13].allowed_methods, ['GET', 'POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][14].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][15].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][15].docstring, 'This is a test.') diff --git a/tests/urls.py b/tests/urls.py index abdf71b..52109e5 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -6,6 +6,14 @@ from rest_framework_docs.views import DRFDocsView from tests import views +inherited_urls = [ + url(r'^view/$', views.InheritedView.as_view(), name="inherited"), +] + +inherited_no_slash_urls = [ + url(r'^no-slash/$', views.InheritedView.as_view(), name="inherited-no-slash"), +] + accounts_urls = [ url(r'^login/$', views.LoginView.as_view(), name="login"), url(r'^login2/$', views.LoginWithSerilaizerClassView.as_view(), name="login2"), @@ -18,12 +26,16 @@ url(r'^test/$', views.TestView.as_view(), name="test-view"), ] + organisations_urls = [ url(r'^create/$', view=views.CreateOrganisationView.as_view(), name="create"), url(r'^(?P[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"), url(r'^(?P[\w-]+)/leave/$', view=views.LeaveOrganisationView.as_view(), name="leave"), url(r'^(?P[\w-]+)/errored/$', view=views.OrganisationErroredView.as_view(), name="errored"), url(r'^(?P[\w-]+)/$', view=views.RetrieveOrganisationView.as_view(), name="organisation"), + + url(r'^inherited/', include(inherited_urls)), + url(r'^inherited2', include(inherited_no_slash_urls)) ] router = SimpleRouter() diff --git a/tests/views.py b/tests/views.py index fcf319d..a7dfc26 100644 --- a/tests/views.py +++ b/tests/views.py @@ -149,3 +149,7 @@ def test_route(self, request): class RetrieveOrganisationView(generics.RetrieveAPIView): serializer_class = serializers.RetrieveOrganisationSerializer + + +class InheritedView(generics.CreateAPIView): + serializer_class = serializers.CreateOrganisationSerializer