diff --git a/aws-replicator/aws_replicator/client/auth_proxy.py b/aws-replicator/aws_replicator/client/auth_proxy.py
index 966abe2..d715bcb 100644
--- a/aws-replicator/aws_replicator/client/auth_proxy.py
+++ b/aws-replicator/aws_replicator/client/auth_proxy.py
@@ -5,20 +5,27 @@
 import subprocess
 import sys
 from functools import cache
+from io import BytesIO
 from typing import Dict, Optional, Tuple
 from urllib.parse import urlparse, urlunparse
 
 import boto3
 import requests
-from botocore.awsrequest import AWSPreparedRequest
+from botocore.awsrequest import AWSPreparedRequest, AWSResponse
+from botocore.httpchecksum import resolve_checksum_context
 from botocore.model import OperationModel
 from localstack import config
 from localstack import config as localstack_config
+from localstack.aws.chain import HandlerChain
+from localstack.aws.chain import RequestContext as AwsRequestContext
+from localstack.aws.gateway import Gateway
 from localstack.aws.protocol.parser import create_parser
 from localstack.aws.spec import load_service
 from localstack.config import external_service_url
 from localstack.constants import AWS_REGION_US_EAST_1, DOCKER_IMAGE_NAME_PRO
 from localstack.http import Request
+from localstack.http import Response as HttpResponse
+from localstack.http.hypercorn import GatewayServer
 from localstack.utils.aws.aws_responses import requests_response
 from localstack.utils.bootstrap import setup_logging
 from localstack.utils.collections import select_attributes
@@ -37,8 +44,6 @@
 from aws_replicator.config import HANDLER_PATH_PROXIES
 from aws_replicator.shared.models import AddProxyRequest, ProxyConfig
 
-from .http2_server import run_server
-
 LOG = logging.getLogger(__name__)
 LOG.setLevel(logging.INFO)
 if config.DEBUG:
@@ -57,6 +62,207 @@
 DEFAULT_BIND_HOST = "127.0.0.1"
 
 
+class AwsProxyHandler:
+    """
+    A handler for an AWS Handler chain that attempts to forward the request using a specific boto3 session.
+    This can be used to proxy incoming requests to real AWS.
+ """ + + def __init__(self, session: boto3.Session = None): + self.session = session or boto3.Session() + + def __call__(self, chain: HandlerChain, context: AwsRequestContext, response: HttpResponse): + # prepare the API invocation parameters + LOG.info( + "Received %s.%s = %s", + context.service.service_name, + context.operation.name, + context.service_request, + ) + + # make the actual API call against upstream AWS (will also calculate a new auth signature) + try: + aws_response = self._make_aws_api_call(context) + except Exception: + LOG.exception( + "Exception while proxying %s.%s to AWS", + context.service.service_name, + context.operation.name, + ) + raise + + # tell the handler chain to respond + LOG.info( + "AWS Response %s.%s: url=%s status_code=%s, headers=%s, content=%s", + context.service.service_name, + context.operation.name, + aws_response.url, + aws_response.status_code, + aws_response.headers, + aws_response.content, + ) + chain.respond(aws_response.status_code, aws_response.content, dict(aws_response.headers)) + + def _make_aws_api_call(self, context: AwsRequestContext) -> AWSResponse: + # TODO: reconcile with AwsRequestProxy from localstack, and other forwarder tools + # create a real AWS client + client = self.session.client(context.service.service_name, region_name=context.region) + operation_model = context.operation + + # prepare API request parameters as expected by boto + api_params = {k: v for k, v in context.service_request.items() if v is not None} + + # this is a stripped down version of botocore's client._make_api_call to immediately get the HTTP + # response instead of a parsed response. + request_context = { + "client_region": client.meta.region_name, + "client_config": client.meta.config, + "has_streaming_input": operation_model.has_streaming_input, + "auth_type": operation_model.auth_type, + } + + ( + endpoint_url, + additional_headers, + properties, + ) = client._resolve_endpoint_ruleset(operation_model, api_params, request_context) + if properties: + # Pass arbitrary endpoint info with the Request + # for use during construction. + request_context["endpoint_properties"] = properties + + request_dict = client._convert_to_request_dict( + api_params=api_params, + operation_model=operation_model, + endpoint_url=endpoint_url, + context=request_context, + headers=additional_headers, + ) + resolve_checksum_context(request_dict, operation_model, api_params) + + if operation_model.has_streaming_input: + request_dict["body"] = request_dict["body"].read() + + self._adjust_request_dict(context.service.service_name, request_dict) + + if operation_model.has_streaming_input: + request_dict["body"] = BytesIO(request_dict["body"]) + + LOG.info("Making AWS request %s", request_dict) + http, _ = client._endpoint.make_request(operation_model, request_dict) + + http: AWSResponse + + # for some elusive reasons, these header modifications are needed (were part of http2_server) + http.headers.pop("Date", None) + http.headers.pop("Server", None) + if operation_model.has_streaming_output: + http.headers.pop("Content-Length", None) + + return http + + def _adjust_request_dict(self, service_name: str, request_dict: Dict): + """Apply minor fixes to the request dict, which seem to be required in the current setup.""" + # TODO: replacing localstack-specific URLs, IDs, etc, should ideally be done in a more generalized + # way. 
+
+        req_body = request_dict.get("body")
+
+        # TODO: fix for switch between path/host addressing
+        # Note: the behavior seems to be different across botocore versions. Seems to be working
+        #  with 1.29.97 (fix below not required) whereas newer versions like 1.29.151 require the fix.
+        if service_name == "s3":
+            body_str = run_safe(lambda: to_str(req_body)) or ""
+
+            request_url = request_dict["url"]
+            url_parsed = list(urlparse(request_url))
+            path_parts = url_parsed[2].strip("/").split("/")
+            bucket_subdomain_prefix = f"://{path_parts[0]}.s3."
+            if bucket_subdomain_prefix in request_url:
+                prefix = f"/{path_parts[0]}"
+                url_parsed[2] = url_parsed[2].removeprefix(prefix)
+                request_dict["url_path"] = request_dict["url_path"].removeprefix(prefix)
+                # replace empty path with "/" (seems required for signature calculation)
+                request_dict["url_path"] = request_dict["url_path"] or "/"
+                url_parsed[2] = url_parsed[2] or "/"
+                # re-construct final URL
+                request_dict["url"] = urlunparse(url_parsed)
+
+            # TODO: this custom fix should not be required - investigate and remove!
+            if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
+                region = request_dict["context"]["client_region"]
+                request_dict["body"] = (
+                    '<CreateBucketConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">'
+                    f"<LocationConstraint>{region}</LocationConstraint></CreateBucketConfiguration>"
+                )
+
+        if service_name == "sqs" and isinstance(req_body, dict):
+            account_id = self._query_account_id_from_aws()
+            if "QueueUrl" in req_body:
+                queue_name = req_body["QueueUrl"].split("/")[-1]
+                req_body["QueueUrl"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
+            if "QueueOwnerAWSAccountId" in req_body:
+                req_body["QueueOwnerAWSAccountId"] = account_id
+        if service_name == "sqs" and request_dict.get("url"):
+            req_json = run_safe(lambda: json.loads(body_str)) or {}
+            account_id = self._query_account_id_from_aws()
+            queue_name = req_json.get("QueueName")
+            if account_id and queue_name:
+                request_dict["url"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
+                req_json["QueueOwnerAWSAccountId"] = account_id
+                request_dict["body"] = to_bytes(json.dumps(req_json))
+
+    def _fix_headers(self, request: Request, service_name: str):
+        if service_name == "s3":
+            # fix the Host header, to avoid bucket addressing issues
+            host = request.headers.get("Host") or ""
+            regex = r"^(https?://)?([0-9.]+|localhost)(:[0-9]+)?"
+            if re.match(regex, host):
+                request.headers["Host"] = re.sub(regex, r"\1s3.localhost.localstack.cloud", host)
+        request.headers.pop("Content-Length", None)
+        request.headers.pop("x-localstack-request-url", None)
+        request.headers.pop("X-Forwarded-For", None)
+        request.headers.pop("X-Localstack-Tgt-Api", None)
+        request.headers.pop("X-Moto-Account-Id", None)
+        request.headers.pop("Remote-Addr", None)
+
+    @cache
+    def _query_account_id_from_aws(self) -> str:
+        sts_client = self.session.client("sts")
+        result = sts_client.get_caller_identity()
+        return result["Account"]
+
+
+class AwsProxyGateway(Gateway):
+    """
+    A handler chain that receives AWS requests, and proxies them transparently to upstream AWS using real
+    credentials. It de-constructs the incoming request, and creates a new request signed with the AWS
+    credentials configured in the environment.
+ """ + + def __init__(self) -> None: + from localstack.aws import handlers + + super().__init__( + request_handlers=[ + handlers.parse_service_name, + handlers.content_decoder, + handlers.add_region_from_header, + handlers.add_account_id, + handlers.parse_service_request, + AwsProxyHandler(), + ], + exception_handlers=[ + handlers.log_exception, + handlers.handle_internal_failure, + ], + context_class=AwsRequestContext, + ) + + class AuthProxyAWS(Server): def __init__(self, config: ProxyConfig, port: int = None): self.config = config @@ -65,9 +271,13 @@ def __init__(self, config: ProxyConfig, port: int = None): def do_run(self): self.register_in_instance() + bind_host = self.config.get("bind_host") or DEFAULT_BIND_HOST - proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request) - proxy.join() + srv = GatewayServer(AwsProxyGateway(), localstack_config.HostAndPort(bind_host, self.port)) + srv.start() + srv.join() + # proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request) + # proxy.join() def proxy_request(self, request: Request, data: bytes) -> Response: parsed = self._extract_region_and_service(request.headers) @@ -214,20 +424,23 @@ def _parse_aws_request( def _adjust_request_dict(self, service_name: str, request_dict: Dict): """Apply minor fixes to the request dict, which seem to be required in the current setup.""" - + # TODO: replacing localstack-specific URLs, IDs, etc, should ideally be done in a more generalized + # way. req_body = request_dict.get("body") - body_str = run_safe(lambda: to_str(req_body)) or "" - - # TODO: this custom fix should not be required - investigate and remove! - if "' - f"{region}" - ) + + if service_name == "s3": + body_str = run_safe(lambda: to_str(req_body)) or "" + + # TODO: this custom fix should not be required - investigate and remove! + if "' + f"{region}" + ) if service_name == "sqs" and isinstance(req_body, dict): account_id = self._query_account_id_from_aws() @@ -327,8 +540,7 @@ def start_aws_auth_proxy_in_container( command = [ "bash", "-c", - # TODO: manually installing quart/h11/hypercorn as a dirty quick fix for now. To be fixed! 
- f"{venv_activate}; pip install h11 hypercorn quart; pip install --upgrade --no-deps '{CLI_PIP_PACKAGE}'", + f"{venv_activate}; pip install --upgrade --no-deps '{CLI_PIP_PACKAGE}'", ] DOCKER_CLIENT.exec_in_container(container_name, command=command) diff --git a/aws-replicator/aws_replicator/server/aws_request_forwarder.py b/aws-replicator/aws_replicator/server/aws_request_forwarder.py index c2e6ba1..2bbe1fe 100644 --- a/aws-replicator/aws_replicator/server/aws_request_forwarder.py +++ b/aws-replicator/aws_replicator/server/aws_request_forwarder.py @@ -16,6 +16,7 @@ from localstack.utils.net import get_addressable_container_host from localstack.utils.strings import to_str, truncate from requests.structures import CaseInsensitiveDict +from rolo.proxy import forward try: from localstack.testing.config import TEST_AWS_ACCESS_KEY_ID @@ -37,15 +38,16 @@ def __call__(self, chain: HandlerChain, context: RequestContext, response: Respo return # forward request to proxy - response = self.forward_request(context, proxy) + response_ = self.forward_request(context, proxy) - if response is None: + if response_ is None: return + # Remove `Transfer-Encoding` header (which could be set to 'chunked'), to prevent client timeouts + response_.headers.pop("Transfer-Encoding", None) + # set response details, then stop handler chain to return response - chain.response.data = response.raw_content - chain.response.status_code = response.status_code - chain.response.headers.update(dict(response.headers)) + response.update_from(response_) chain.stop() def select_proxy(self, context: RequestContext) -> Optional[ProxyInstance]: @@ -126,6 +128,22 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ port = proxy["port"] request = context.request target_host = get_addressable_container_host(default_local_hostname=LOCALHOST) + + try: + LOG.info("Forwarding request: %s", context) + response = forward(request, f"http://{target_host}:{port}") + LOG.info( + "Received response: status=%s headers=%s body=%s", + response.status_code, + response.headers, + response.data, + ) + except Exception: + LOG.exception("Exception while forwarding request") + raise + + return response + url = f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}" # inject Auth header, to ensure we're passing the right region to the proxy (e.g., for Cognito InitiateAuth) @@ -158,7 +176,7 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ ) except requests.exceptions.ConnectionError: # remove unreachable proxy - LOG.info("Removing unreachable AWS forward proxy due to connection issue: %s", url) + LOG.exception("Removing unreachable AWS forward proxy due to connection issue: %s", url) self.PROXY_INSTANCES.pop(port, None) return result diff --git a/aws-replicator/aws_replicator/server/extension.py b/aws-replicator/aws_replicator/server/extension.py index 7fe30e0..aca9103 100644 --- a/aws-replicator/aws_replicator/server/extension.py +++ b/aws-replicator/aws_replicator/server/extension.py @@ -12,6 +12,10 @@ class AwsReplicatorExtension(Extension): name = "aws-replicator" def on_extension_load(self): + logging.getLogger("aws_replicator").setLevel( + logging.DEBUG if config.DEBUG else logging.INFO + ) + if config.GATEWAY_SERVER == "twisted": LOG.warning( "AWS resource replicator: The aws-replicator extension currently requires hypercorn as "