diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 61ff13f..10749e9 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -352,10 +352,8 @@ class CallableRequest(_typing.Generic[_core.T]): _C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] -def _on_call_handler(func: _C2, - request: Request, - enforce_app_check: bool, - verify_token: bool = True) -> Response: +def _on_call_handler(func: _C2, request: Request, + enforce_app_check: bool) -> Response: try: if not _util.valid_on_call_request(request): _logging.error("Invalid request, unable to process.") @@ -365,8 +363,7 @@ def _on_call_handler(func: _C2, data=_json.loads(request.data)["data"], ) - token_status = _util.on_call_check_tokens(request, - verify_token=verify_token) + token_status = _util.on_call_check_tokens(request) if token_status.auth == _util.OnCallTokenState.INVALID: raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, @@ -420,7 +417,7 @@ def _on_call_handler(func: _C2, def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]: """ Handler which handles HTTPS requests. - Requires a function that takes a ``Request`` and ``Response`` object, + Requires a function that takes a ``Request`` and ``Response`` object, the same signature as a Flask app. Example: diff --git a/src/firebase_functions/private/util.py b/src/firebase_functions/private/util.py index 09521fc..b24ea26 100644 --- a/src/firebase_functions/private/util.py +++ b/src/firebase_functions/private/util.py @@ -212,11 +212,10 @@ def as_dict(self) -> dict: def _on_call_check_auth_token( - request: _Request, - verify_token: bool = True, + request: _Request ) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]: """ - Validates the auth token in a callable request. + Validates the auth token in a callable request. If verify_token is False, the token will be decoded without verification. """ authorization = request.headers.get("Authorization") @@ -227,10 +226,7 @@ def _on_call_check_auth_token( return OnCallTokenState.INVALID try: id_token = authorization.replace("Bearer ", "") - if verify_token: - auth_token = _auth.verify_id_token(id_token) - else: - auth_token = _unsafe_decode_id_token(id_token) + auth_token = _auth.verify_id_token(id_token) return auth_token # pylint: disable=broad-except except Exception as err: @@ -273,25 +269,23 @@ def _unsafe_decode_id_token(token: str): return payload -def on_call_check_tokens(request: _Request, - verify_token: bool = True) -> _OnCallTokenVerification: +def on_call_check_tokens(request: _Request) -> _OnCallTokenVerification: """Check tokens""" verifications = _OnCallTokenVerification() - auth_token = _on_call_check_auth_token(request, verify_token=verify_token) + auth_token = _on_call_check_auth_token(request) if auth_token is None: verifications.auth = OnCallTokenState.MISSING elif isinstance(auth_token, dict): verifications.auth = OnCallTokenState.VALID verifications.auth_token = auth_token - if verify_token: - app_token = _on_call_check_app_token(request) - if app_token is None: - verifications.app = OnCallTokenState.MISSING - elif isinstance(app_token, dict): - verifications.app = OnCallTokenState.VALID - verifications.app_token = app_token + app_token = _on_call_check_app_token(request) + if app_token is None: + verifications.app = OnCallTokenState.MISSING + elif isinstance(app_token, dict): + verifications.app = OnCallTokenState.VALID + verifications.app_token = app_token log_payload = { **verifications.as_dict(), @@ -301,7 +295,7 @@ def on_call_check_tokens(request: _Request, } errs = [] - if verify_token and verifications.app == OnCallTokenState.INVALID: + if verifications.app == OnCallTokenState.INVALID: errs.append(("AppCheck token was rejected.", log_payload)) if verifications.auth == OnCallTokenState.INVALID: diff --git a/src/firebase_functions/tasks_fn.py b/src/firebase_functions/tasks_fn.py index e0ecf3b..7b8c675 100644 --- a/src/firebase_functions/tasks_fn.py +++ b/src/firebase_functions/tasks_fn.py @@ -16,14 +16,55 @@ # pylint: disable=protected-access import typing as _typing import functools as _functools +import dataclasses as _dataclasses +import json as _json -from flask import Request, Response +from flask import Request, Response, make_response as _make_response, jsonify as _jsonify +import firebase_functions.core as _core import firebase_functions.options as _options import firebase_functions.private.util as _util -from firebase_functions.https_fn import CallableRequest, _on_call_handler +from firebase_functions.https_fn import CallableRequest, HttpsError, FunctionsErrorCode + +from functions_framework import logging as _logging _C = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] +_C1 = _typing.Callable[[Request], Response] +_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] + + +def _on_call_handler(func: _C2, request: Request) -> Response: + try: + if not _util.valid_on_call_request(request): + _logging.error("Invalid request, unable to process.") + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + context: CallableRequest = CallableRequest( + raw_request=request, + data=_json.loads(request.data)["data"], + ) + + instance_id = request.headers.get("Firebase-Instance-ID-Token") + if instance_id is not None: + # Validating the token requires an http request, so we don't do it. + # If the user wants to use it for something, it will be validated then. + # Currently, the only real use case for this token is for sending + # pushes with FCM. In that case, the FCM APIs will validate the token. + context = _dataclasses.replace( + context, + instance_id_token=request.headers.get( + "Firebase-Instance-ID-Token"), + ) + result = _core._with_init(func)(context) + return _jsonify(result=result) + # Disable broad exceptions lint since we want to handle all exceptions here + # and wrap as an HttpsError. + # pylint: disable=broad-except + except Exception as err: + if not isinstance(err, HttpsError): + _logging.error("Unhandled error: %s", err) + err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") + status = err._http_error_code.status + return _make_response(_jsonify(error=err._as_dict()), status) @_util.copy_func_kwargs(_options.TaskQueueOptions) @@ -53,10 +94,7 @@ def on_task_dispatched_decorator(func: _C): @_functools.wraps(func) def on_task_dispatched_wrapped(request: Request) -> Response: - return _on_call_handler(func, - request, - enforce_app_check=False, - verify_token=False) + return _on_call_handler(func, request) _util.set_func_endpoint_attr( on_task_dispatched_wrapped, diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py index 531594c..b16ede3 100644 --- a/tests/test_tasks_fn.py +++ b/tests/test_tasks_fn.py @@ -71,41 +71,6 @@ def example(request: CallableRequest[object]) -> str: '{"result":"Hello World"}\n', ) - def test_token_is_decoded(self): - """ - Test that the token is decoded instead of verifying auth first. - """ - app = Flask(__name__) - - @on_task_dispatched() - def example(request: CallableRequest[object]) -> str: - auth = request.auth - # Make mypy happy - if auth is None: - self.fail("Auth is None") - return "No Auth" - self.assertEqual(auth.token["sub"], "firebase") - self.assertEqual(auth.token["name"], "John Doe") - return "Hello World" - - with app.test_request_context("/"): - # pylint: disable=line-too-long - test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0" - environ = EnvironBuilder( - method="POST", - headers={ - "Authorization": f"Bearer {test_token}" - }, - json={ - "data": { - "test": "value" - }, - }, - ).get_environ() - request = Request(environ) - response = example(request) - self.assertEqual(response.status_code, 200) - def test_calls_init(self): hello = None