Skip to content

Commit e39ebc6

Browse files
authored
Delegating token generation to the google.auth.credentials library (#43)
* Delegating token generation to the google.auth.credentials library * Provided a default implementation for get_access_token() * Updated API doc
1 parent 884ac60 commit e39ebc6

File tree

6 files changed

+33
-101
lines changed

6 files changed

+33
-101
lines changed

firebase_admin/__init__.py

-21
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
_clock = datetime.datetime.utcnow
3131

3232
_DEFAULT_APP_NAME = '[DEFAULT]'
33-
_CLOCK_SKEW_SECONDS = 300
3433

3534

3635
def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
@@ -182,7 +181,6 @@ def __init__(self, name, credential, options):
182181
'with a valid credential instance.')
183182
self._credential = credential
184183
self._options = _AppOptions(options)
185-
self._token = None
186184
self._lock = threading.RLock()
187185
self._services = {}
188186

@@ -198,25 +196,6 @@ def credential(self):
198196
def options(self):
199197
return self._options
200198

201-
def _get_token(self):
202-
"""Returns an OAuth2 bearer token.
203-
204-
This method may return a cached token. But it handles cache invalidation, and therefore
205-
is guaranteed to always return unexpired tokens.
206-
207-
Returns:
208-
string: An unexpired OAuth2 token.
209-
"""
210-
if not self._token_valid():
211-
self._token = self._credential.get_access_token()
212-
return self._token.access_token
213-
214-
def _token_valid(self):
215-
if self._token is None:
216-
return False
217-
skewed_expiry = self._token.expiry - datetime.timedelta(seconds=_CLOCK_SKEW_SECONDS)
218-
return _clock() < skewed_expiry
219-
220199
def _get_service(self, name, initializer):
221200
"""Returns the service instance identified by the given name.
222201

firebase_admin/credentials.py

+9-30
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,17 @@ class Base(object):
3636
"""Provides OAuth2 access tokens for accessing Firebase services."""
3737

3838
def get_access_token(self):
39-
"""Fetches a Google OAuth2 access token using this credential instance."""
40-
raise NotImplementedError
39+
"""Fetches a Google OAuth2 access token using this credential instance.
40+
41+
Returns:
42+
AccessTokenInfo: An access token obtained using the credential.
43+
"""
44+
google_cred = self.get_credential()
45+
google_cred.refresh(_request)
46+
return AccessTokenInfo(google_cred.token, google_cred.expiry)
4147

4248
def get_credential(self):
43-
"""Returns the credential instance used for authentication."""
49+
"""Returns the Google credential instance used for authentication."""
4450
raise NotImplementedError
4551

4652

@@ -88,15 +94,6 @@ def signer(self):
8894
def service_account_email(self):
8995
return self._g_credential.service_account_email
9096

91-
def get_access_token(self):
92-
"""Fetches a Google OAuth2 access token using this certificate credential.
93-
94-
Returns:
95-
AccessTokenInfo: An access token obtained using the credential.
96-
"""
97-
self._g_credential.refresh(_request)
98-
return AccessTokenInfo(self._g_credential.token, self._g_credential.expiry)
99-
10097
def get_credential(self):
10198
"""Returns the underlying Google credential.
10299
@@ -118,15 +115,6 @@ def __init__(self):
118115
super(ApplicationDefault, self).__init__()
119116
self._g_credential, self._project_id = google.auth.default(scopes=_scopes)
120117

121-
def get_access_token(self):
122-
"""Fetches a Google OAuth2 access token using this application default credential.
123-
124-
Returns:
125-
AccessTokenInfo: An access token obtained using the credential.
126-
"""
127-
self._g_credential.refresh(_request)
128-
return AccessTokenInfo(self._g_credential.token, self._g_credential.expiry)
129-
130118
def get_credential(self):
131119
"""Returns the underlying Google credential.
132120
@@ -184,15 +172,6 @@ def client_secret(self):
184172
def refresh_token(self):
185173
return self._g_credential.refresh_token
186174

187-
def get_access_token(self):
188-
"""Fetches a Google OAuth2 access token using this refresh token credential.
189-
190-
Returns:
191-
AccessTokenInfo: An access token obtained using the credential.
192-
"""
193-
self._g_credential.refresh(_request)
194-
return AccessTokenInfo(self._g_credential.token, self._g_credential.expiry)
195-
196175
def get_credential(self):
197176
"""Returns the underlying Google credential.
198177

firebase_admin/db.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import numbers
2626
import sys
2727

28+
from google.auth import transport
2829
import requests
2930
import six
3031
from six.moves import urllib
@@ -553,14 +554,12 @@ def __init__(self, **kwargs):
553554
554555
Keyword Args:
555556
url: Firebase Realtime Database URL.
556-
auth: An instance of requests.auth.AuthBase for authenticating outgoing HTTP requests.
557557
session: An HTTP session created using the requests module.
558558
auth_override: A dictionary representing auth variable overrides or None (optional).
559559
Defaults to empty dict, which provides admin privileges. A None value here provides
560560
un-authenticated guest privileges.
561561
"""
562562
self._url = kwargs.pop('url')
563-
self._auth = kwargs.pop('auth')
564563
self._session = kwargs.pop('session')
565564
auth_override = kwargs.pop('auth_override', {})
566565
if auth_override != {}:
@@ -591,9 +590,10 @@ def from_app(cls, app):
591590
raise ValueError('Invalid databaseAuthVariableOverride option: "{0}". Override '
592591
'value must be a dict or None.'.format(auth_override))
593592

594-
session = requests.Session()
593+
g_credential = app.credential.get_credential()
594+
session = transport.requests.AuthorizedSession(g_credential)
595595
session.headers.update({'User-Agent': _USER_AGENT})
596-
return _Client(url='https://{0}'.format(parsed.netloc), auth=_OAuth(app),
596+
return _Client(url='https://{0}'.format(parsed.netloc),
597597
session=session, auth_override=auth_override)
598598

599599
def request(self, method, urlpath, **kwargs):
@@ -628,7 +628,7 @@ def _do_request(self, method, urlpath, **kwargs):
628628
params = self._auth_override
629629
kwargs['params'] = params
630630
try:
631-
resp = self._session.request(method, self._url + urlpath, auth=self._auth, **kwargs)
631+
resp = self._session.request(method, self._url + urlpath, **kwargs)
632632
resp.raise_for_status()
633633
return resp
634634
except requests.exceptions.RequestException as error:
@@ -663,13 +663,3 @@ def close(self):
663663
self._session.close()
664664
self._auth = None
665665
self._url = None
666-
667-
668-
class _OAuth(requests.auth.AuthBase):
669-
def __init__(self, app):
670-
self._app = app
671-
672-
def __call__(self, req):
673-
# pylint: disable=protected-access
674-
req.headers['Authorization'] = 'Bearer {0}'.format(self._app._get_token())
675-
return req

tests/test_app.py

-22
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414

1515
"""Tests for firebase_admin.App."""
16-
import datetime
17-
import json
1816
import os
1917

2018
import pytest
@@ -161,23 +159,3 @@ def test_app_delete(self, init_app):
161159
firebase_admin.get_app(init_app.name)
162160
with pytest.raises(ValueError):
163161
firebase_admin.delete_app(init_app)
164-
165-
def test_get_token(self, init_app):
166-
mock_response = {'access_token': 'mock_access_token_1', 'expires_in': 3600}
167-
credentials._request = testutils.MockRequest(200, json.dumps(mock_response))
168-
169-
assert init_app._get_token() == 'mock_access_token_1'
170-
171-
mock_response = {'access_token': 'mock_access_token_2', 'expires_in': 3600}
172-
credentials._request = testutils.MockRequest(200, json.dumps(mock_response))
173-
174-
expiry = init_app._token.expiry
175-
# should return same token from cache
176-
firebase_admin._clock = lambda: expiry - datetime.timedelta(
177-
seconds=firebase_admin._CLOCK_SKEW_SECONDS + 1)
178-
assert init_app._get_token() == 'mock_access_token_1'
179-
180-
# should return new token from RPC call
181-
firebase_admin._clock = lambda: expiry - datetime.timedelta(
182-
seconds=firebase_admin._CLOCK_SKEW_SECONDS)
183-
assert init_app._get_token() == 'mock_access_token_2'

tests/test_db.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Tests for firebase_admin.db."""
1616
import collections
17-
import datetime
1817
import json
1918
import sys
2019

@@ -47,12 +46,14 @@ def send(self, request, **kwargs):
4746

4847

4948
class MockCredential(credentials.Base):
50-
def get_access_token(self):
51-
expiry = datetime.datetime.utcnow() + datetime.timedelta(hours=24)
52-
return credentials.AccessTokenInfo('mock-token', expiry)
49+
"""A mock Firebase credential implementation."""
50+
51+
def __init__(self):
52+
self._g_credential = testutils.MockGoogleCredential()
5353

5454
def get_credential(self):
55-
return None
55+
return self._g_credential
56+
5657

5758
class _Object(object):
5859
pass
@@ -402,15 +403,15 @@ def test_no_app(self):
402403
db.reference()
403404

404405
def test_no_db_url(self):
405-
firebase_admin.initialize_app(credentials.Base())
406+
firebase_admin.initialize_app(MockCredential())
406407
with pytest.raises(ValueError):
407408
db.reference()
408409

409410
@pytest.mark.parametrize('url', [
410411
'https://test.firebaseio.com', 'https://test.firebaseio.com/'
411412
])
412413
def test_valid_db_url(self, url):
413-
firebase_admin.initialize_app(credentials.Base(), {'databaseURL' : url})
414+
firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url})
414415
ref = db.reference()
415416
assert ref._client._url == 'https://test.firebaseio.com'
416417
assert ref._client._auth_override is None
@@ -420,13 +421,13 @@ def test_valid_db_url(self, url):
420421
True, False, 1, 0, dict(), list(), tuple(), _Object()
421422
])
422423
def test_invalid_db_url(self, url):
423-
firebase_admin.initialize_app(credentials.Base(), {'databaseURL' : url})
424+
firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url})
424425
with pytest.raises(ValueError):
425426
db.reference()
426427

427428
@pytest.mark.parametrize('override', [{}, {'uid':'user1'}, None])
428429
def test_valid_auth_override(self, override):
429-
firebase_admin.initialize_app(credentials.Base(), {
430+
firebase_admin.initialize_app(MockCredential(), {
430431
'databaseURL' : 'https://test.firebaseio.com',
431432
'databaseAuthVariableOverride': override
432433
})
@@ -441,7 +442,7 @@ def test_valid_auth_override(self, override):
441442
@pytest.mark.parametrize('override', [
442443
'', 'foo', 0, 1, True, False, list(), tuple(), _Object()])
443444
def test_invalid_auth_override(self, override):
444-
firebase_admin.initialize_app(credentials.Base(), {
445+
firebase_admin.initialize_app(MockCredential(), {
445446
'databaseURL' : 'https://test.firebaseio.com',
446447
'databaseAuthVariableOverride': override
447448
})
@@ -450,12 +451,10 @@ def test_invalid_auth_override(self, override):
450451

451452
def test_app_delete(self):
452453
app = firebase_admin.initialize_app(
453-
credentials.Base(), {'databaseURL' : 'https://test.firebaseio.com'})
454+
MockCredential(), {'databaseURL' : 'https://test.firebaseio.com'})
454455
ref = db.reference()
455456
assert ref is not None
456-
assert ref._client._auth is not None
457457
firebase_admin.delete_app(app)
458-
assert ref._client._auth is None
459458
with pytest.raises(ValueError):
460459
db.reference()
461460

tests/testutils.py

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Common utility classes and functions for testing."""
1616
import os
1717

18+
from google.auth import credentials
1819
from google.auth import transport
1920
import firebase_admin
2021

@@ -68,3 +69,9 @@ def __init__(self, status, response):
6869

6970
def __call__(self, *args, **kwargs):
7071
return self.response
72+
73+
74+
class MockGoogleCredential(credentials.Credentials):
75+
"""A mock Google authentication credential."""
76+
def refresh(self, request):
77+
self.token = 'mock-token'

0 commit comments

Comments
 (0)