Skip to content

Allow http.HTTPMethod enum values in @action() decorator #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 19, 2023
Merged
62 changes: 26 additions & 36 deletions rest_framework-stubs/decorators.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Protocol, TypeVar

Expand All @@ -16,6 +17,30 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias
_View = TypeVar("_View", bound=Callable[..., HttpResponseBase])
_P = ParamSpec("_P")
_RESP = TypeVar("_RESP", bound=HttpResponseBase)
_MixedCaseHttpMethod: TypeAlias = Literal[
"GET",
"POST",
"DELETE",
"PUT",
"PATCH",
"TRACE",
"HEAD",
"OPTIONS",
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
if sys.version_info >= (3, 11):
from http import HTTPMethod

_HttpMethod: TypeAlias = _MixedCaseHttpMethod | HTTPMethod
else:
_HttpMethod: TypeAlias = _MixedCaseHttpMethod

class MethodMapper(dict):
def __init__(self, action: _View, methods: Sequence[str]) -> None: ...
Expand All @@ -29,43 +54,8 @@ class MethodMapper(dict):
def options(self, func: _View) -> _View: ...
def trace(self, func: _View) -> _View: ...

_LOWER_CASE_HTTP_VERBS: TypeAlias = Sequence[
Literal[
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
]

_MIXED_CASE_HTTP_VERBS: TypeAlias = Sequence[
Literal[
"GET",
"POST",
"DELETE",
"PUT",
"PATCH",
"TRACE",
"HEAD",
"OPTIONS",
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
]

class ViewSetAction(Protocol[_View]):
detail: bool
methods: _LOWER_CASE_HTTP_VERBS
url_path: str
url_name: str
kwargs: Mapping[str, Any]
Expand All @@ -84,7 +74,7 @@ def throttle_classes(throttle_classes: Sequence[BaseThrottle | type[BaseThrottle
def permission_classes(permission_classes: Sequence[_PermissionClass]) -> Callable[[_View], _View]: ...
def schema(view_inspector: ViewInspector | type[ViewInspector] | None) -> Callable[[_View], _View]: ...
def action(
methods: _MIXED_CASE_HTTP_VERBS | None = ...,
methods: Sequence[_HttpMethod] | None = ...,
detail: bool = ...,
url_path: str | None = ...,
url_name: str | None = ...,
Expand Down
20 changes: 18 additions & 2 deletions tests/typecheck/test_decorators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,25 @@
from rest_framework.response import Response

class MyView(viewsets.ViewSet):

@action(methods=("get",), detail=False)
def view_func_1(self, request: Request) -> Response: ...
@action(methods=["post"], detail=False)
def view_func_2(self, request: Request) -> Response: ...
@action(methods=("GET",), detail=False)
def view_func_3(self, request: Request) -> Response: ...

- case: method_decorator_http_libary
skip: sys.version_info < (3, 11)
main: |
from http import HTTPMethod
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response

@action(methods=["post",], detail=False)
MY_VAR: HTTPMethod = HTTPMethod.POST
class MyView(viewsets.ViewSet):
@action(methods=[HTTPMethod.GET], detail=False)
def view_func_1(self, request: Request) -> Response: ...
@action(methods=[MY_VAR], detail=False)
def view_func_2(self, request: Request) -> Response: ...