From cffdae1a301bdfb4fef055acfffc16c80fca59a7 Mon Sep 17 00:00:00 2001 From: Brandon Miles Date: Wed, 23 Oct 2013 15:39:10 -0700 Subject: [PATCH] Added support for client_credentials grant type --- provider/oauth2/forms.py | 6 ++++++ provider/oauth2/tests.py | 19 +++++++++++++++++++ provider/oauth2/views.py | 8 +++++++- provider/views.py | 30 +++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 2 deletions(-) diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index 1e0485a7..9540b659 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -301,3 +301,9 @@ def clean(self): data['user'] = user return data + +class ClientCredentialsGrantForm(ScopeMixin, OAuthForm): + """ + Validate a client credentials grant request. + """ + scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False) diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index e5ad8f59..f0ef7ccc 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -356,6 +356,25 @@ def test_password_grant(self): self.assertEqual(400, response.status_code, response.content) self.assertEqual('invalid_grant', json.loads(response.content)['error']) + def test_client_credentials_grant(self): + response = self.client.post(self.access_token_url(), { + 'grant_type': 'client_credentials', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + }) + + self.assertEqual(200, response.status_code, response.content) + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'client_credentials', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret + 'invalid', + }) + + self.assertEqual(400, response.status_code, response.content) + self.assertEqual('invalid_client', + json.loads(response.content)['error']) + class AuthBackendTest(BaseOAuth2TestCase): fixtures = ['test_oauth2'] diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index f90871e9..208367ea 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -5,7 +5,7 @@ from ..utils import now from .forms import AuthorizationRequestForm, AuthorizationForm from .forms import PasswordGrantForm, RefreshTokenGrantForm -from .forms import AuthorizationCodeGrantForm +from .forms import AuthorizationCodeGrantForm, ClientCredentialsGrantForm from .models import Client, RefreshToken, AccessToken from .backends import BasicClientBackend, RequestParamsClientBackend @@ -90,6 +90,12 @@ def get_password_grant(self, request, data, client): raise OAuthError(form.errors) return form.cleaned_data + def get_client_credentials_grant(self, request, data, client): + form = ClientCredentialsGrantForm(data, client=client) + if not form.is_valid(): + raise OAuthError(form.errors) + return form.cleaned_data + def get_access_token(self, request, user, scope, client): try: # Attempt to fetch an existing access token. diff --git a/provider/views.py b/provider/views.py index a774d4df..3fe50ce8 100644 --- a/provider/views.py +++ b/provider/views.py @@ -353,7 +353,8 @@ class AccessToken(OAuthView, Mixin): Authentication backends used to authenticate a particular client. """ - grant_types = ['authorization_code', 'refresh_token', 'password'] + grant_types = ['authorization_code', 'refresh_token', 'password', + 'client_credentials'] """ The default grant types supported by this view. """ @@ -382,6 +383,14 @@ def get_password_grant(self, request, data, client): """ raise NotImplementedError + def get_client_credentials_grant(self, request, data, client): + """ + Return the optional parameters (scope) associated with this request. + + :return: ``tuple`` - ``(True or False, options)`` + """ + raise NotImplementedError + def get_access_token(self, request, user, scope, client): """ Override to handle fetching of an existing access token. @@ -506,6 +515,23 @@ def password(self, request, data, client): return self.access_token_response(at) + def client_credentials(self, request, data, client): + """ + Handle ``grant_type=client_credentials`` requests as defined in + :rfc:`4.4`. + """ + data = self.get_client_credentials_grant(request, data, client) + scope = data.get('scope') + + if constants.SINGLE_ACCESS_TOKEN: + at = self.get_access_token(request, client.user, scope, client) + else: + at = self.create_access_token(request, client.user, scope, client) + rt = self.create_refresh_token(request, client.user, scope, at, + client) + + return self.access_token_response(at) + def get_handler(self, grant_type): """ Return a function or method that is capable handling the ``grant_type`` @@ -518,6 +544,8 @@ def get_handler(self, grant_type): return self.refresh_token elif grant_type == 'password': return self.password + elif grant_type == 'client_credentials': + return self.client_credentials return None def get(self, request):