From 12826138656079a0218444b9bfe28bee499741c0 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 5 Jul 2024 07:03:33 -0700 Subject: [PATCH 1/2] Add support for validating webhooks Signed-off-by: Mattt Zmuda --- replicate/client.py | 8 +++ replicate/webhook.py | 140 +++++++++++++++++++++++++++++++++++++++++ tests/test_webhooks.py | 121 +++++++++++++++++++++++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 replicate/webhook.py create mode 100644 tests/test_webhooks.py diff --git a/replicate/client.py b/replicate/client.py index b92245ab..b28da8b0 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -30,6 +30,7 @@ from replicate.run import async_run, run from replicate.stream import async_stream, stream from replicate.training import Trainings +from replicate.webhook import Webhooks if TYPE_CHECKING: from replicate.stream import ServerSentEvent @@ -144,6 +145,13 @@ def trainings(self) -> Trainings: """ return Trainings(client=self) + @property + def webhooks(self) -> Webhooks: + """ + Namespace for operations related to webhooks. + """ + return Webhooks(client=self) + def run( self, ref: str, diff --git a/replicate/webhook.py b/replicate/webhook.py new file mode 100644 index 00000000..880762ab --- /dev/null +++ b/replicate/webhook.py @@ -0,0 +1,140 @@ +import base64 +import hmac +from hashlib import sha256 +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + overload, +) + +from replicate.resource import Namespace, Resource + +if TYPE_CHECKING: + import httpx + + +class WebhookSigningSecret(Resource): + key: str + + +class Webhooks(Namespace): + @property + def default(self) -> "Webhooks.Default": + """ + Namespace for operations related to the default webhook. + """ + + return self.Default(self._client) + + class Default(Namespace): + def secret(self) -> WebhookSigningSecret: + """ + Get the default webhook signing secret. + + Returns: + WebhookSigningSecret: The default webhook signing secret. + """ + + resp = self._client._request("GET", "/v1/webhooks/default/secret") + return WebhookSigningSecret(**resp.json()) + + async def async_secret(self) -> WebhookSigningSecret: + """ + Get the default webhook signing secret. + + Returns: + WebhookSigningSecret: The default webhook signing secret. + """ + + resp = await self._client._async_request( + "GET", "/v1/webhooks/default/secret" + ) + return WebhookSigningSecret(**resp.json()) + + @overload + @staticmethod + def validate(request: "httpx.Request", secret: WebhookSigningSecret) -> bool: ... + + @overload + @staticmethod + def validate( + headers: Dict[str, str], body: str, secret: WebhookSigningSecret + ) -> bool: ... + + @staticmethod + def validate( # type: ignore + request: Optional["httpx.Request"] = None, + headers: Optional[Dict[str, str]] = None, + body: Optional[str] = None, + secret: Optional[WebhookSigningSecret] = None, + ) -> bool: + """ + Validate the signature from an incoming webhook request using the provided secret. + + Args: + request (httpx.Request): The request object. + headers (Dict[str, str]): The request headers. + body (str): The request body. + secret (WebhookSigningSecret): The webhook signing secret. + + Returns: + bool: True if the request is valid, False otherwise. + + Raises: + ValueError: If there are missing headers, invalid secret key format, or missing body. + """ + + if not secret: + raise ValueError("Missing webhook signing secret") + + if request and any([headers, body]): + raise ValueError("Only one of request or headers/body can be provided") + + if request and request.headers: + webhook_id = request.headers.get("webhook-id") + timestamp = request.headers.get("webhook-timestamp") + signature = request.headers.get("webhook-signature") + body = request.content.decode("utf-8") + else: + if not headers: + raise ValueError("Missing webhook headers") + + # Convert headers to case-insensitive dictionary + headers = {k.lower(): v for k, v in headers.items()} + + webhook_id = headers.get("webhook-id") + timestamp = headers.get("webhook-timestamp") + signature = headers.get("webhook-signature") + + if not webhook_id: + raise ValueError("Missing webhook id") + if not timestamp: + raise ValueError("Missing webhook timestamp") + if not signature: + raise ValueError("Missing webhook signature") + if not body: + raise ValueError("Missing webhook body") + + signed_content = f"{webhook_id}.{timestamp}.{body}" + + key_parts = secret.key.split("_") + if len(key_parts) != 2: + raise ValueError(f"Invalid secret key format: {secret.key}") + + secret_bytes = base64.b64decode(key_parts[1]) + + h = hmac.new(secret_bytes, signed_content.encode(), sha256) + computed_signature = h.digest() + + for sig in signature.split(): + sig_parts = sig.split(",") + if len(sig_parts) < 2: + raise ValueError(f"Invalid signature format: {sig}") + + sig_bytes = base64.b64decode(sig_parts[1]) + + if hmac.compare_digest(sig_bytes, computed_signature): + return True + + return False diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py new file mode 100644 index 00000000..d6d7fc39 --- /dev/null +++ b/tests/test_webhooks.py @@ -0,0 +1,121 @@ +import httpx +import pytest +import respx +from httpx import Request + +from replicate.client import Client +from replicate.webhook import Webhooks, WebhookSigningSecret + + +@pytest.fixture +def webhook_signing_secret(): + # This is a test secret and should not be used in production + return WebhookSigningSecret(key="whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +@respx.mock +async def test_get_webhook_secret(async_flag, webhook_signing_secret): + respx.get("https://api.replicate.com/v1/webhooks/default/secret").mock( + return_value=httpx.Response(200, json={"key": webhook_signing_secret.key}) + ) + + client = Client(api_token="test-token") + + if async_flag: + secret = await client.webhooks.default.async_secret() + else: + secret = client.webhooks.default.secret() + + assert isinstance(secret, WebhookSigningSecret) + assert secret.key == webhook_signing_secret.key + + body = '{"test": 2432232314}' + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + } + + request = Request( + method="POST", + url="http://test.host/webhook", + headers=headers, + content=body.encode(), + ) + + is_valid = client.webhooks.validate(request=request, secret=secret) + assert is_valid + + +def test_validate_webhook_invalid_signature(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + is_valid = Webhooks.validate( + headers=headers, body=body, secret=webhook_signing_secret + ) + assert not is_valid + + +def test_validate_webhook_missing_webhook_id(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + } + body = '{"test": 2432232314}' + + with pytest.raises(ValueError, match="Missing webhook id"): + Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret) + + +def test_validate_webhook_invalid_secret(): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + with pytest.raises(ValueError, match="Invalid secret key format"): + Webhooks.validate( + headers=headers, + body=body, + secret=WebhookSigningSecret(key="invalid_secret_format"), + ) + + +def test_validate_webhook_missing_headers(webhook_signing_secret): + headers = None + body = '{"test": 2432232314}' + + with pytest.raises(ValueError, match="Missing webhook headers"): + Webhooks.validate( + headers=headers, # type: ignore + body=body, + secret=webhook_signing_secret, + ) + + +def test_validate_webhook_missing_body(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + } + body = None + + with pytest.raises(ValueError, match="Missing webhook body"): + Webhooks.validate( + headers=headers, + body=body, # type: ignore + secret=webhook_signing_secret, + ) From 9ee61b86eddd0fd6d9385151434513619f058a08 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 5 Jul 2024 11:45:16 -0700 Subject: [PATCH 2/2] Add optional tolerance parameter to validate timestamp Raise typed exceptions to indicate problem Change return type to None Signed-off-by: Mattt Zmuda --- replicate/webhook.py | 79 ++++++++++++++++++++++++++++++++++-------- tests/test_webhooks.py | 63 +++++++++++++++++++++++++++------ 2 files changed, 117 insertions(+), 25 deletions(-) diff --git a/replicate/webhook.py b/replicate/webhook.py index 880762ab..4258a05c 100644 --- a/replicate/webhook.py +++ b/replicate/webhook.py @@ -18,6 +18,30 @@ class WebhookSigningSecret(Resource): key: str +class WebhookValidationError(ValueError): + """Base exception for webhook validation errors.""" + + +class MissingWebhookHeaderError(WebhookValidationError): + """Exception raised when a required webhook header is missing.""" + + +class InvalidSecretKeyError(WebhookValidationError): + """Exception raised when the secret key format is invalid.""" + + +class MissingWebhookBodyError(WebhookValidationError): + """Exception raised when the webhook body is missing.""" + + +class InvalidTimestampError(WebhookValidationError): + """Exception raised when the webhook timestamp is invalid or outside the tolerance.""" + + +class InvalidSignatureError(WebhookValidationError): + """Exception raised when the webhook signature is invalid.""" + + class Webhooks(Namespace): @property def default(self) -> "Webhooks.Default": @@ -54,12 +78,19 @@ async def async_secret(self) -> WebhookSigningSecret: @overload @staticmethod - def validate(request: "httpx.Request", secret: WebhookSigningSecret) -> bool: ... + def validate( + request: "httpx.Request", + secret: WebhookSigningSecret, + tolerance: Optional[int] = None, + ) -> bool: ... @overload @staticmethod def validate( - headers: Dict[str, str], body: str, secret: WebhookSigningSecret + headers: Dict[str, str], + body: str, + secret: WebhookSigningSecret, + tolerance: Optional[int] = None, ) -> bool: ... @staticmethod @@ -68,7 +99,8 @@ def validate( # type: ignore headers: Optional[Dict[str, str]] = None, body: Optional[str] = None, secret: Optional[WebhookSigningSecret] = None, - ) -> bool: + tolerance: Optional[int] = None, + ) -> None: """ Validate the signature from an incoming webhook request using the provided secret. @@ -77,12 +109,17 @@ def validate( # type: ignore headers (Dict[str, str]): The request headers. body (str): The request body. secret (WebhookSigningSecret): The webhook signing secret. + tolerance (Optional[int]): Maximum allowed time difference (in seconds) between the current time and the webhook timestamp. Returns: - bool: True if the request is valid, False otherwise. + None: If the request is valid. Raises: - ValueError: If there are missing headers, invalid secret key format, or missing body. + MissingWebhookHeaderError: If required webhook headers are missing. + InvalidSecretKeyError: If the secret key format is invalid. + MissingWebhookBodyError: If the webhook body is missing. + InvalidTimestampError: If the webhook timestamp is invalid or outside the tolerance. + InvalidSignatureError: If the webhook signature is invalid. """ if not secret: @@ -98,7 +135,7 @@ def validate( # type: ignore body = request.content.decode("utf-8") else: if not headers: - raise ValueError("Missing webhook headers") + raise MissingWebhookHeaderError("Missing webhook headers") # Convert headers to case-insensitive dictionary headers = {k.lower(): v for k, v in headers.items()} @@ -108,33 +145,47 @@ def validate( # type: ignore signature = headers.get("webhook-signature") if not webhook_id: - raise ValueError("Missing webhook id") + raise MissingWebhookHeaderError("Missing webhook id") if not timestamp: - raise ValueError("Missing webhook timestamp") + raise MissingWebhookHeaderError("Missing webhook timestamp") if not signature: - raise ValueError("Missing webhook signature") + raise MissingWebhookHeaderError("Missing webhook signature") if not body: - raise ValueError("Missing webhook body") + raise MissingWebhookBodyError("Missing webhook body") + + if tolerance is not None: + import time # pylint: disable=import-outside-toplevel + + current_time = int(time.time()) + webhook_time = int(timestamp) + time_difference = abs(current_time - webhook_time) + if time_difference > tolerance: + raise InvalidTimestampError( + f"Webhook timestamp is outside the allowed tolerance of {tolerance} seconds" + ) signed_content = f"{webhook_id}.{timestamp}.{body}" key_parts = secret.key.split("_") if len(key_parts) != 2: - raise ValueError(f"Invalid secret key format: {secret.key}") + raise InvalidSecretKeyError(f"Invalid secret key format: {secret.key}") secret_bytes = base64.b64decode(key_parts[1]) h = hmac.new(secret_bytes, signed_content.encode(), sha256) computed_signature = h.digest() + valid = False for sig in signature.split(): sig_parts = sig.split(",") if len(sig_parts) < 2: - raise ValueError(f"Invalid signature format: {sig}") + raise InvalidSignatureError(f"Invalid signature format: {sig}") sig_bytes = base64.b64decode(sig_parts[1]) if hmac.compare_digest(sig_bytes, computed_signature): - return True + valid = True + break - return False + if not valid: + raise InvalidSignatureError("Webhook signature is invalid") diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index d6d7fc39..2e20ee25 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -1,10 +1,20 @@ +import time + import httpx import pytest import respx from httpx import Request from replicate.client import Client -from replicate.webhook import Webhooks, WebhookSigningSecret +from replicate.webhook import ( + InvalidSecretKeyError, + InvalidSignatureError, + InvalidTimestampError, + MissingWebhookBodyError, + MissingWebhookHeaderError, + Webhooks, + WebhookSigningSecret, +) @pytest.fixture @@ -46,8 +56,7 @@ async def test_get_webhook_secret(async_flag, webhook_signing_secret): content=body.encode(), ) - is_valid = client.webhooks.validate(request=request, secret=secret) - assert is_valid + client.webhooks.validate(request=request, secret=secret) def test_validate_webhook_invalid_signature(webhook_signing_secret): @@ -59,10 +68,8 @@ def test_validate_webhook_invalid_signature(webhook_signing_secret): } body = '{"test": 2432232314}' - is_valid = Webhooks.validate( - headers=headers, body=body, secret=webhook_signing_secret - ) - assert not is_valid + with pytest.raises(InvalidSignatureError, match="Webhook signature is invalid"): + Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret) def test_validate_webhook_missing_webhook_id(webhook_signing_secret): @@ -71,7 +78,7 @@ def test_validate_webhook_missing_webhook_id(webhook_signing_secret): } body = '{"test": 2432232314}' - with pytest.raises(ValueError, match="Missing webhook id"): + with pytest.raises(MissingWebhookHeaderError, match="Missing webhook id"): Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret) @@ -84,7 +91,7 @@ def test_validate_webhook_invalid_secret(): } body = '{"test": 2432232314}' - with pytest.raises(ValueError, match="Invalid secret key format"): + with pytest.raises(InvalidSecretKeyError, match="Invalid secret key format"): Webhooks.validate( headers=headers, body=body, @@ -96,7 +103,7 @@ def test_validate_webhook_missing_headers(webhook_signing_secret): headers = None body = '{"test": 2432232314}' - with pytest.raises(ValueError, match="Missing webhook headers"): + with pytest.raises(MissingWebhookHeaderError, match="Missing webhook headers"): Webhooks.validate( headers=headers, # type: ignore body=body, @@ -113,9 +120,43 @@ def test_validate_webhook_missing_body(webhook_signing_secret): } body = None - with pytest.raises(ValueError, match="Missing webhook body"): + with pytest.raises(MissingWebhookBodyError, match="Missing webhook body"): Webhooks.validate( headers=headers, body=body, # type: ignore secret=webhook_signing_secret, ) + + +@pytest.mark.parametrize( + "timestamp_offset, timestamp_invalid", + [ + (-3601, True), # 1 hour and 1 second in the past + (3601, True), # 1 hour and 1 second in the future + (-3599, False), # 59 minutes and 59 seconds in the past (valid) + (3599, False), # 59 minutes and 59 seconds in the future (valid) + ], +) +def test_validate_webhook_timestamp( + webhook_signing_secret, timestamp_offset, timestamp_invalid +): + current_time = int(time.time()) + timestamp = str(current_time + timestamp_offset) + + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": timestamp, + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + with pytest.raises( + (InvalidTimestampError if timestamp_invalid else InvalidSignatureError), + ): + Webhooks.validate( + headers=headers, + body=body, + secret=webhook_signing_secret, + tolerance=3600, + )