diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 60be96811..2f6713d41 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -24,6 +24,7 @@ from firebase_admin import _user_identifier from firebase_admin import _user_import from firebase_admin import _user_mgt +from firebase_admin import _utils class Client: @@ -36,18 +37,37 @@ def __init__(self, app, tenant_id=None): 2. set the project ID explicitly via Firebase App options, or 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") - credential = app.credential.get_credential() + credential = None version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + # Non-default endpoint URLs for emulator support are set in this dict later. + endpoint_urls = {} + self.emulated = False + + # If an emulator is present, check that the given value matches the expected format and set + # endpoint URLs to use the emulator. Additionally, use a fake credential. + emulator_host = _auth_utils.get_emulator_host() + if emulator_host: + base_url = 'http://{0}/identitytoolkit.googleapis.com'.format(emulator_host) + endpoint_urls['v1'] = base_url + '/v1' + endpoint_urls['v2beta1'] = base_url + '/v2beta1' + credential = _utils.EmulatorAdminCredentials() + self.emulated = True + else: + # Use credentials if provided + credential = app.credential.get_credential() + http_client = _http_client.JsonHttpClient( credential=credential, headers={'X-Client-Version': version_header}, timeout=timeout) self._tenant_id = tenant_id - self._token_generator = _token_gen.TokenGenerator(app, http_client) + self._token_generator = _token_gen.TokenGenerator( + app, http_client, url_override=endpoint_urls.get('v1')) self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(http_client, app.project_id, tenant_id) + self._user_manager = _user_mgt.UserManager( + http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v1')) self._provider_manager = _auth_providers.ProviderConfigClient( - http_client, app.project_id, tenant_id) + http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2beta1')) @property def tenant_id(self): diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 46de6fe5f..5126c862c 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -166,9 +166,10 @@ class ProviderConfigClient: PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' - def __init__(self, http_client, project_id, tenant_id=None): + def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client - self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) + url_prefix = url_override or self.PROVIDER_CONFIG_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 2226675f9..d8e49b1a1 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -15,6 +15,7 @@ """Firebase auth utils.""" import json +import os import re from urllib import parse @@ -22,6 +23,7 @@ from firebase_admin import _utils +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', @@ -66,6 +68,19 @@ def __iter__(self): return self +def get_emulator_host(): + emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') + if emulator_host and '//' in emulator_host: + raise ValueError( + 'Invalid {0}: "{1}". It must follow format "host:port".'.format( + EMULATOR_HOST_ENV_VAR, emulator_host)) + return emulator_host + + +def is_emulated(): + return get_emulator_host() != '' + + def validate_uid(uid, required=False): if uid is None and not required: return None diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 18a8008c7..562e77fa5 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -53,6 +53,19 @@ METADATA_SERVICE_URL = ('http://metadata.google.internal/computeMetadata/v1/instance/' 'service-accounts/default/email') +# Emulator fake account +AUTH_EMULATOR_EMAIL = 'firebase-auth-emulator@example.com' + + +class _EmulatedSigner(google.auth.crypt.Signer): + key_id = None + + def __init__(self): + pass + + def sign(self, message): + return b'' + class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" @@ -78,21 +91,28 @@ def from_iam(cls, request, google_cred, service_account): signer = iam.Signer(request, google_cred, service_account) return _SigningProvider(signer, service_account) + @classmethod + def for_emulator(cls): + return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL) + class TokenGenerator: """Generates custom tokens and session cookies.""" ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, app, http_client): + def __init__(self, app, http_client, url_override=None): self.app = app self.http_client = http_client self.request = transport.requests.Request() - self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, app.project_id) + url_prefix = url_override or self.ID_TOOLKIT_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, app.project_id) self._signing_provider = None def _init_signing_provider(self): """Initializes a signing provider by following the go/firebase-admin-sign protocol.""" + if _auth_utils.is_emulated(): + return _SigningProvider.for_emulator() # If the SDK was initialized with a service account, use it to sign bytes. google_cred = self.app.credential.get_credential() if isinstance(google_cred, google.oauth2.service_account.Credentials): @@ -285,12 +305,14 @@ def verify(self, token, request): verify_id_token_msg = ( 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) + emulated = _auth_utils.is_emulated() + error_message = None if audience == FIREBASE_AUDIENCE: error_message = ( '{0} expects {1}, but was given a custom ' 'token.'.format(self.operation, self.articled_short_name)) - elif not header.get('kid'): + elif not emulated and not header.get('kid'): if header.get('alg') == 'HS256' and payload.get( 'v') == 0 and 'uid' in payload.get('d', {}): error_message = ( @@ -298,7 +320,7 @@ def verify(self, token, request): 'token.'.format(self.operation, self.articled_short_name)) else: error_message = 'Firebase {0} has no "kid" claim.'.format(self.short_name) - elif header.get('alg') != 'RS256': + elif not emulated and header.get('alg') != 'RS256': error_message = ( 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' '"{1}". {2}'.format(self.short_name, header.get('alg'), verify_id_token_msg)) @@ -329,11 +351,14 @@ def verify(self, token, request): raise self._invalid_token_error(error_message) try: - verified_claims = google.oauth2.id_token.verify_token( - token, - request=request, - audience=self.project_id, - certs_url=self.cert_url) + if emulated: + verified_claims = payload + else: + verified_claims = google.oauth2.id_token.verify_token( + token, + request=request, + audience=self.project_id, + certs_url=self.cert_url) verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 1d97dd504..b60c4d100 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -573,9 +573,10 @@ class UserManager: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, http_client, project_id, tenant_id=None): + def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client - self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, project_id) + url_prefix = url_override or self.ID_TOOLKIT_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index a5fc8d022..8c640276c 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -18,6 +18,7 @@ import json import socket +import google.auth import googleapiclient import httplib2 import requests @@ -339,3 +340,20 @@ def _parse_platform_error(content, status_code): if not msg: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) return error_dict, msg + + +# Temporarily disable the lint rule. For more information see: +# https://github.com/googleapis/google-auth-library-python/pull/561 +# pylint: disable=abstract-method +class EmulatorAdminCredentials(google.auth.credentials.Credentials): + """ Credentials for use with the firebase local emulator. + + This is used instead of user-supplied credentials or ADC. It will silently do nothing when + asked to refresh credentials. + """ + def __init__(self): + google.auth.credentials.Credentials.__init__(self) + self.token = 'owner' + + def refresh(self, request): + pass diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 3384bd440..1d293bb89 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -27,7 +27,6 @@ import threading from urllib import parse -import google.auth import requests import firebase_admin @@ -808,7 +807,7 @@ def get_client(self, db_url=None): emulator_config = self._get_emulator_config(parsed_url) if emulator_config: - credential = _EmulatorAdminCredentials() + credential = _utils.EmulatorAdminCredentials() base_url = emulator_config.base_url params = {'ns': emulator_config.namespace} else: @@ -965,14 +964,3 @@ def _extract_error_message(cls, response): message = 'Unexpected response from database: {0}'.format(response.content.decode()) return message - -# Temporarily disable the lint rule. For more information see: -# https://github.com/googleapis/google-auth-library-python/pull/561 -# pylint: disable=abstract-method -class _EmulatorAdminCredentials(google.auth.credentials.Credentials): - def __init__(self): - google.auth.credentials.Credentials.__init__(self) - self.token = 'owner' - - def refresh(self, request): - pass diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 124aea3cc..0947c77ae 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -21,10 +21,18 @@ import firebase_admin from firebase_admin import auth from firebase_admin import exceptions -from firebase_admin import _auth_providers from tests import testutils -USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2beta1'.format( + AUTH_EMULATOR_HOST) +URL_PROJECT_SUFFIX = '/projects/mock-project-id' +USER_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, + 'PREFIX': ID_TOOLKIT_URL + URL_PROJECT_SUFFIX, +} OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') @@ -39,12 +47,18 @@ INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] -@pytest.fixture(scope='module') -def user_mgt_app(): +@pytest.fixture(scope='module', params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(USER_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) + monkeypatch.setitem(USER_MGT_URLS, 'PREFIX', EMULATED_ID_TOOLKIT_URL + URL_PROJECT_SUFFIX) app = firebase_admin.initialize_app(testutils.MockCredential(), name='providerConfig', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() def _instrument_provider_mgt(app, status, payload): @@ -52,7 +66,7 @@ def _instrument_provider_mgt(app, status, payload): provider_manager = client._provider_manager recorder = [] provider_manager.http_client.session.mount( - _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + USER_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return recorder @@ -90,7 +104,7 @@ def test_get(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, @@ -116,7 +130,7 @@ def test_create(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == self.OIDC_CONFIG_REQUEST @@ -136,7 +150,7 @@ def test_create_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -156,7 +170,7 @@ def test_create_empty_values(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -186,7 +200,7 @@ def test_update(self, user_mgt_app): assert req.method == 'PATCH' mask = ['clientId', 'displayName', 'enabled', 'issuer'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == self.OIDC_CONFIG_REQUEST @@ -201,7 +215,7 @@ def test_update_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'PATCH' assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == {'displayName': 'oidcProviderName'} @@ -217,7 +231,7 @@ def test_update_empty_values(self, user_mgt_app): assert req.method == 'PATCH' mask = ['displayName', 'enabled'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} @@ -236,7 +250,7 @@ def test_delete(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): @@ -259,7 +273,7 @@ def test_list_single_page(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs?pageSize=100') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -277,7 +291,7 @@ def test_list_multiple_pages(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -289,7 +303,7 @@ def test_list_multiple_pages(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -310,7 +324,7 @@ def test_paged_iteration(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -322,7 +336,7 @@ def test_paged_iteration(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) with pytest.raises(StopIteration): next(iterator) @@ -421,7 +435,8 @@ def test_get(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs/saml.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, @@ -451,7 +466,7 @@ def test_create(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == self.SAML_CONFIG_REQUEST @@ -471,7 +486,7 @@ def test_create_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -491,7 +506,7 @@ def test_create_empty_values(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -528,7 +543,7 @@ def test_update(self, user_mgt_app): 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == self.SAML_CONFIG_REQUEST @@ -543,7 +558,7 @@ def test_update_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'PATCH' assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == {'displayName': 'samlProviderName'} @@ -559,7 +574,7 @@ def test_update_empty_values(self, user_mgt_app): assert req.method == 'PATCH' mask = ['displayName', 'enabled'] assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} @@ -578,7 +593,8 @@ def test_delete(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs/saml.provider') def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) @@ -613,7 +629,8 @@ def test_list_single_page(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -631,7 +648,7 @@ def test_list_multiple_pages(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -643,7 +660,7 @@ def test_list_multiple_pages(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -664,7 +681,7 @@ def test_paged_iteration(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -676,7 +693,7 @@ def test_paged_iteration(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) with pytest.raises(StopIteration): next(iterator) diff --git a/tests/test_db.py b/tests/test_db.py index 5f8ba4b51..aa2c83bd9 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -26,6 +26,7 @@ from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient +from firebase_admin import _utils from tests import testutils @@ -730,7 +731,7 @@ def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_name assert ref._client._base_url == expected_base_url assert ref._client.params.get('ns') == expected_namespace if expected_base_url.startswith('http://localhost'): - assert isinstance(ref._client.credential, db._EmulatorAdminCredentials) + assert isinstance(ref._client.credential, _utils.EmulatorAdminCredentials) else: assert isinstance(ref._client.credential, testutils.MockGoogleCredential) finally: diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index f88c87ff4..29c70da80 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -55,6 +55,14 @@ 'NonEmptyDictToken': {'a': 1}, } +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +TOKEN_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, +} + # Fixture for mocking a HTTP server httpserver = plugin.httpserver @@ -68,13 +76,18 @@ def _merge_jwt_claims(defaults, overrides): def verify_custom_token(custom_token, expected_claims, tenant_id=None): assert isinstance(custom_token, bytes) - token = google.oauth2.id_token.verify_token( - custom_token, - testutils.MockRequest(200, MOCK_PUBLIC_CERTS), - _token_gen.FIREBASE_AUDIENCE) + expected_email = MOCK_SERVICE_ACCOUNT_EMAIL + if _is_emulated(): + expected_email = _token_gen.AUTH_EMULATOR_EMAIL + token = jwt.decode(custom_token, verify=False) + else: + token = google.oauth2.id_token.verify_token( + custom_token, + testutils.MockRequest(200, MOCK_PUBLIC_CERTS), + _token_gen.FIREBASE_AUDIENCE) assert token['uid'] == MOCK_UID - assert token['iss'] == MOCK_SERVICE_ACCOUNT_EMAIL - assert token['sub'] == MOCK_SERVICE_ACCOUNT_EMAIL + assert token['iss'] == expected_email + assert token['sub'] == expected_email if tenant_id is None: assert 'tenant_id' not in token else: @@ -121,7 +134,7 @@ def _instrument_user_manager(app, status, payload): user_manager = client._user_manager recorder = [] user_manager.http_client.session.mount( - _token_gen.TokenGenerator.ID_TOOLKIT_URL, + TOKEN_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -133,23 +146,41 @@ def _overwrite_iam_request(app, request): client = auth._get_client(app) client._token_generator.request = request -@pytest.fixture(scope='module') -def auth_app(): + +def _is_emulated(): + emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') + return emulator_host and '//' not in emulator_host + + +# These fixtures are set to the default function scope as the emulator environment variable bleeds +# over when in module scope. +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def auth_app(request): """Returns an App initialized with a mock service account credential. This can be used in any scenario where the private key is required. Use user_mgt_app for everything else. """ + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(TOKEN_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) app = firebase_admin.initialize_app(MOCK_CREDENTIAL, name='tokenGen') yield app firebase_admin.delete_app(app) - -@pytest.fixture(scope='module') -def user_mgt_app(): + monkeypatch.undo() + +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(TOKEN_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() @pytest.fixture def env_var_app(request): @@ -212,6 +243,12 @@ def test_invalid_params(self, auth_app, values): auth.create_custom_token(user, claims, app=auth_app) def test_noncert_credential(self, user_mgt_app): + if _is_emulated(): + # Should work fine with the emulator, so do a condensed version of + # test_sign_with_iam below. + custom_token = auth.create_custom_token(MOCK_UID, app=user_mgt_app).decode() + self._verify_signer(custom_token, _token_gen.AUTH_EMULATOR_EMAIL) + return with pytest.raises(ValueError): auth.create_custom_token(MOCK_UID, app=user_mgt_app) @@ -286,7 +323,7 @@ def test_sign_with_discovery_failure(self): def _verify_signer(self, token, signer): segments = token.split('.') assert len(segments) == 3 - body = json.loads(base64.b64decode(segments[1]).decode()) + body = jwt.decode(token, verify=False) assert body['iss'] == signer assert body['sub'] == signer @@ -388,14 +425,24 @@ class TestVerifyIdToken: 'BadFormatToken': 'foobar' } - @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) - def test_valid_token(self, user_mgt_app, id_token): - _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - claims = auth.verify_id_token(id_token, app=user_mgt_app) + tokens_accepted_in_emulator = [ + 'NoKid', + 'WrongKid', + 'FutureToken', + 'ExpiredToken' + ] + + def _assert_valid_token(self, id_token, app): + claims = auth.verify_id_token(id_token, app=app) assert claims['admin'] is True assert claims['uid'] == claims['sub'] assert claims['firebase']['sign_in_provider'] == 'provider' + @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) + def test_valid_token(self, user_mgt_app, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + self._assert_valid_token(id_token, app=user_mgt_app) + def test_valid_token_with_tenant(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) claims = auth.verify_id_token(TEST_ID_TOKEN_WITH_TENANT, app=user_mgt_app) @@ -440,8 +487,12 @@ def test_invalid_arg(self, user_mgt_app, id_token): auth.verify_id_token(id_token, app=user_mgt_app) assert 'Illegal ID token provided' in str(excinfo.value) - @pytest.mark.parametrize('id_token', invalid_tokens.values(), ids=list(invalid_tokens)) - def test_invalid_token(self, user_mgt_app, id_token): + @pytest.mark.parametrize('id_token_key', list(invalid_tokens)) + def test_invalid_token(self, user_mgt_app, id_token_key): + id_token = self.invalid_tokens[id_token_key] + if _is_emulated() and id_token_key in self.tokens_accepted_in_emulator: + self._assert_valid_token(id_token, user_mgt_app) + return _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) @@ -451,6 +502,9 @@ def test_invalid_token(self, user_mgt_app, id_token): def test_expired_token(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) id_token = self.invalid_tokens['ExpiredToken'] + if _is_emulated(): + self._assert_valid_token(id_token, user_mgt_app) + return with pytest.raises(auth.ExpiredIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) assert isinstance(excinfo.value, auth.InvalidIdTokenError) @@ -488,6 +542,10 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) + if _is_emulated(): + # Shouldn't fetch certificates in emulator mode. + self._assert_valid_token(TEST_ID_TOKEN, app=user_mgt_app) + return with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_id_token(TEST_ID_TOKEN, app=user_mgt_app) assert 'Could not fetch certificates' in str(excinfo.value) @@ -522,20 +580,28 @@ class TestVerifySessionCookie: 'IDToken': TEST_ID_TOKEN, } + cookies_accepted_in_emulator = [ + 'NoKid', + 'WrongKid', + 'FutureCookie', + 'ExpiredCookie' + ] + + def _assert_valid_cookie(self, cookie, app, check_revoked=False): + claims = auth.verify_session_cookie(cookie, app=app, check_revoked=check_revoked) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_valid_cookie(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, user_mgt_app) @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_valid_cookie_check_revoked(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=True) @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie): @@ -549,9 +615,7 @@ def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie def test_revoked_cookie_does_not_check_revoked(self, user_mgt_app, revoked_tokens, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=False) @pytest.mark.parametrize('cookie', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) def test_invalid_args(self, user_mgt_app, cookie): @@ -560,8 +624,12 @@ def test_invalid_args(self, user_mgt_app, cookie): auth.verify_session_cookie(cookie, app=user_mgt_app) assert 'Illegal session cookie provided' in str(excinfo.value) - @pytest.mark.parametrize('cookie', invalid_cookies.values(), ids=list(invalid_cookies)) - def test_invalid_cookie(self, user_mgt_app, cookie): + @pytest.mark.parametrize('cookie_key', list(invalid_cookies)) + def test_invalid_cookie(self, user_mgt_app, cookie_key): + cookie = self.invalid_cookies[cookie_key] + if _is_emulated() and cookie_key in self.cookies_accepted_in_emulator: + self._assert_valid_cookie(cookie, user_mgt_app) + return _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) with pytest.raises(auth.InvalidSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) @@ -571,6 +639,9 @@ def test_invalid_cookie(self, user_mgt_app, cookie): def test_expired_cookie(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) cookie = self.invalid_cookies['ExpiredCookie'] + if _is_emulated(): + self._assert_valid_cookie(cookie, user_mgt_app) + return with pytest.raises(auth.ExpiredSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) assert isinstance(excinfo.value, auth.InvalidSessionCookieError) @@ -603,6 +674,10 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) + if _is_emulated(): + # Shouldn't fetch certificates in emulator mode. + auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) + return with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) assert 'Could not fetch certificates' in str(excinfo.value) @@ -619,9 +694,11 @@ def test_certificate_caching(self, user_mgt_app, httpserver): verifier.cookie_verifier.cert_url = httpserver.url verifier.id_token_verifier.cert_url = httpserver.url verifier.verify_session_cookie(TEST_SESSION_COOKIE) - assert len(httpserver.requests) == 1 + # No requests should be made in emulated mode + request_count = 0 if _is_emulated() else 1 + assert len(httpserver.requests) == request_count # Subsequent requests should not fetch certs from the server verifier.verify_session_cookie(TEST_SESSION_COOKIE) - assert len(httpserver.requests) == 1 + assert len(httpserver.requests) == request_count verifier.verify_id_token(TEST_ID_TOKEN) - assert len(httpserver.requests) == 1 + assert len(httpserver.requests) == request_count diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 240f19bdc..ac80a92a6 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -50,19 +50,32 @@ } MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) -USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' - TEST_TIMEOUT = 42 +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +URL_PROJECT_SUFFIX = '/projects/mock-project-id' +USER_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, + 'PREFIX': ID_TOOLKIT_URL + URL_PROJECT_SUFFIX, +} -@pytest.fixture(scope='module') -def user_mgt_app(): +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(USER_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) + monkeypatch.setitem(USER_MGT_URLS, 'PREFIX', EMULATED_ID_TOOLKIT_URL + URL_PROJECT_SUFFIX) app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() -@pytest.fixture(scope='module') +@pytest.fixture def user_mgt_app_with_timeout(): app = firebase_admin.initialize_app( testutils.MockCredential(), @@ -77,7 +90,7 @@ def _instrument_user_manager(app, status, payload): user_manager = client._user_manager recorder = [] user_manager.http_client.session.mount( - _user_mgt.UserManager.ID_TOOLKIT_URL, + USER_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -121,7 +134,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, want_url) + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) if want_body: body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/testutils.py b/tests/testutils.py index 556155253..4a77c9d80 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -16,6 +16,8 @@ import io import os +import pytest + from google.auth import credentials from google.auth import transport from requests import adapters @@ -58,6 +60,15 @@ def run_without_project_id(func): os.environ[env_var] = gcloud_project +def new_monkeypatch(): + try: + return pytest.MonkeyPatch() + except AttributeError: + # Fallback for Python 3.5 + from _pytest.monkeypatch import MonkeyPatch + return MonkeyPatch() + + class MockResponse(transport.Response): def __init__(self, status, response): self._status = status