import logging
from base64 import b64decode
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Union, cast

from flask import Flask, json, jsonify, make_response, request
from flask.views import View

from ..local import infra
from ..local.context import format_context
from ..local.event import format_http_event

if TYPE_CHECKING:
    from flask.wrappers import Request as FlaskRequest
    from flask.wrappers import Response as FlaskResponse

    from ..framework.v1 import hints

# Methods a handler is routed for when add_handler() gets no explicit list.
# TODO?: Switch to the enum documented at
# https://docs.python.org/3/library/http.html#http-methods once Python 3.11+
# is the minimum supported version.
ALL_HTTP_METHODS = "GET HEAD POST PUT DELETE CONNECT OPTIONS TRACE PATCH".split()

# 6 MiB: requests whose Content-Length exceeds this only trigger a warning
# locally (see HandlerWrapper.emulate_core_preprocess).
MAX_CONTENT_LENGTH = 6 * 1024 * 1024


class HandlerWrapper(View):  # type: ignore # Subclass of untyped class
    """View that emulates the provider-side processing of requests.

    Every request goes through the following stages:

    1. :meth:`emulate_core_preprocess` - CoreRT guard, run before the handler.
    2. :meth:`emulate_subruntime` - calls the handler and serializes its result.
    3. :meth:`emulate_core_postprocess` - turns the serialized result into a
       response record.
    4. :meth:`resp_record_to_flask_response` - turns the record into the final
       HTTP response.
    """

    # Flask builds the view instance once (from the arguments given to
    # as_view) and reuses it for every request.
    init_every_request: ClassVar[bool] = False

    def __init__(self, handler: "hints.Handler") -> None:
        """Wrap ``handler``, which is called as ``handler(event, context)``."""
        self.handler = handler

    @property
    def logger(self) -> "logging.Logger":
        """Utility function to get a logger named after the wrapped handler."""
        return logging.getLogger(self.handler.__name__)

    def dispatch_request(self, *_args: Any, **_kwargs: Any) -> "FlaskResponse":
        """Handle http requests.

        URL arguments captured by the routing rules (e.g. ``path``) are
        ignored: the event is built from the current request object instead.
        """
        self.emulate_core_preprocess(request)

        event = format_http_event(request)
        infra.inject_ingress_headers(request, event)

        context = format_context(self.handler)

        sub_response = self.emulate_subruntime(event, context)
        record = self.emulate_core_postprocess(sub_response)

        resp = self.resp_record_to_flask_response(record)
        infra.inject_egress_headers(resp)

        return resp

    def emulate_core_preprocess(self, req: "FlaskRequest") -> None:
        """Emulate the CoreRT guard.

        Locally the request is never rejected: the conditions checked here are
        only reported as warnings so the developer knows about them.

        :param req: incoming Flask request
        """
        if req.content_length and req.content_length > MAX_CONTENT_LENGTH:
            # Sizes are reported in MiB (bytes / 2**20)
            self.logger.warning(
                "Request is too big, should not exceed %s Mb but is %s Mb",
                MAX_CONTENT_LENGTH / (1 << 20),
                req.content_length / (1 << 20),
            )
        if req.path in ["/favicon.ico", "/robots.txt"]:
            self.logger.warning(
                "Requests to either favicon.ico or robots.txt are dropped"
            )

    def emulate_subruntime(
        self, event: "hints.Event", context: "hints.Context"
    ) -> "FlaskResponse":
        """Emulate the subruntime.

        Calls the handler and serializes its result: a ``str`` is used as the
        response body as-is, anything else is serialized to JSON.

        :param event: event built from the incoming request
        :param context: context built for the handler
        :return: the serialized handler result
        :raises Exception: whatever the handler raises is re-raised unchanged
        """
        try:
            function_result = self.handler(event, context)
        except Exception:  # pylint: disable=broad-exception-caught # from subRT
            self.logger.warning(
                "Exception caught in handler %s, this will return a 500 when deployed",
                self.handler.__name__,
            )
            # Bare raise re-raises the active exception with its original
            # traceback; Flask then turns it into an error response.
            raise
        if isinstance(function_result, str):
            return make_response(function_result)
        return jsonify(function_result)

    def emulate_core_postprocess(
        self, sub_response: "FlaskResponse"
    ) -> "hints.ResponseRecord":
        """Emulate the CoreRT runtime response processing.

        While it may seem unnecessary to generate an intermediate response,
        the serialization followed by a deserialization does affect the final
        response. It also makes it easier to maintain compatibility with the
        CoreRT.

        :param sub_response: serialized handler result
        :return: a response record. If the body is a JSON object, the record is
            that object restricted to the known keys. Otherwise it is built
            from the raw status code, headers and body of ``sub_response``.
        """
        body = sub_response.get_data(as_text=True)
        response: "hints.ResponseRecord" = {
            "statusCode": sub_response.status_code,
            "headers": dict(sub_response.headers.items()),
            "body": body,
        }

        try:
            record = json.loads(body)
        except JSONDecodeError:
            # e.g. a plain string returned by the handler
            return response

        # Valid JSON that is not an object (list, number, string...) is
        # passed through untouched.
        if not isinstance(record, dict):
            return response

        # Keep only the known response-record keys. The parsed record
        # replaces the defaults above instead of being merged into them
        # (hence no |=), so a key missing from the record is missing from
        # the result.
        return cast(
            "hints.ResponseRecord",
            {
                key: val
                for key, val in record.items()
                if key in response or key == "isBase64Encoded"
            },
        )

    def resp_record_to_flask_response(
        self, record: "hints.ResponseRecord"
    ) -> "FlaskResponse":
        """Transform the ResponseRecord into an http response.

        :param record: response record produced by :meth:`emulate_core_postprocess`
        :return: the Flask response sent back to the client
        :raises binascii.Error: if ``isBase64Encoded`` is set and the body is
            not valid base64
        """
        body: Union[str, bytes] = record.get("body", "")
        if record.get("isBase64Encoded") and body:
            # validate=True rejects characters outside the base64 alphabet
            # instead of silently discarding them.
            body = b64decode(cast(str, body).encode("utf-8"), validate=True)

        # A missing statusCode yields None, leaving Flask's default status.
        resp = make_response(body, record.get("statusCode"))

        # Those headers are added for convenience, but will be
        # overwritten if set in the handler
        resp.headers.add("Access-Control-Allow-Origin", "*")
        resp.headers.add("Access-Control-Allow-Headers", "Content-Type")

        # Headers.update replaces existing values for the given keys, so the
        # handler's headers take precedence over the defaults above.
        resp.headers.update(record.get("headers") or {})

        return resp


class LocalFunctionServer:
    """LocalFunctionServer serves Scaleway FaaS handlers on a local http server.

    Register handlers with :meth:`add_handler` (calls can be chained since it
    returns the server), then start the server with :meth:`serve`.
    """

    def __init__(self) -> None:
        self.app = Flask("serverless_local")

    def add_handler(
        self,
        handler: "hints.Handler",
        relative_url: Optional[str] = None,
        http_methods: Optional[List[str]] = None,
    ) -> "LocalFunctionServer":
        """Add a handler to be served by the server.

        The handler is registered both on ``relative_url`` itself and on any
        sub-path ``relative_url/<path>``.

        :param handler: serverless python handler
        :param relative_url: path to the handler, defaults to / + handler's name.
            A leading "/" is added when missing.
        :param http_methods: HTTP methods for the handler, defaults to all methods.
            Method names are case-insensitive.
        :return: the server itself, to allow chaining calls
        """
        relative_url = relative_url or "/" + handler.__name__
        if not relative_url.startswith("/"):
            relative_url = "/" + relative_url

        methods = [method.upper() for method in (http_methods or ALL_HTTP_METHODS)]

        # as_view is a classmethod: it builds the HandlerWrapper itself from
        # the forwarded handler, so no throwaway instance is needed here.
        view = HandlerWrapper.as_view(handler.__name__, handler)

        # Without an explicit `methods` list Flask would only route GET (plus
        # the automatically added HEAD/OPTIONS), so pass it to both rules.
        self.app.add_url_rule(
            f"{relative_url}/<path:path>", methods=methods, view_func=view
        )
        self.app.add_url_rule(
            relative_url,
            methods=methods,
            defaults={"path": ""},
            view_func=view,
        )

        return self

    def serve(
        self, *args: Any, port: int = 8080, debug: bool = True, **kwargs: Any
    ) -> None:
        """Serve the added FaaS handlers.

        Any other positional or keyword argument is forwarded to
        :meth:`flask.Flask.run`.

        :param port: port that the server should listen on, defaults to 8080
        :param debug: run Flask in debug mode, enables hot-reloading and stack trace.
        """
        # port/debug are keyword-only, so they can never also be in kwargs.
        self.app.run(*args, port=port, debug=debug, **kwargs)


def serve_handler(
    handler: "hints.Handler",
    *args: Any,
    port: int = 8080,
    debug: bool = True,
    **kwargs: Any,
) -> None:
    """Serve a single FaaS handler on a local http server.

    The handler is registered with the relative url "/" and all HTTP methods.
    Any other positional or keyword argument is forwarded to
    :meth:`flask.Flask.run`.

    :param handler: serverless python handler
    :param port: port that the server should listen on, defaults to 8080
    :param debug: run Flask in debug mode, enables hot-reloading and stack trace.

    Example::

        def handle(event, _context):
            return {"body": event["httpMethod"]}

        serve_handler(handle, port=8080)
    """
    server = LocalFunctionServer()
    server.add_handler(handler=handler, relative_url="/")
    server.serve(*args, port=port, debug=debug, **kwargs)