diff --git a/replicate/client.py b/replicate/client.py index b92245ab..b28da8b0 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -30,6 +30,7 @@ from replicate.run import async_run, run from replicate.stream import async_stream, stream from replicate.training import Trainings +from replicate.webhook import Webhooks if TYPE_CHECKING: from replicate.stream import ServerSentEvent @@ -144,6 +145,13 @@ def trainings(self) -> Trainings: """ return Trainings(client=self) + @property + def webhooks(self) -> Webhooks: + """ + Namespace for operations related to webhooks. + """ + return Webhooks(client=self) + def run( self, ref: str, diff --git a/replicate/webhook.py b/replicate/webhook.py new file mode 100644 index 00000000..4258a05c --- /dev/null +++ b/replicate/webhook.py @@ -0,0 +1,191 @@ +import base64 +import hmac +from hashlib import sha256 +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + overload, +) + +from replicate.resource import Namespace, Resource + +if TYPE_CHECKING: + import httpx + + +class WebhookSigningSecret(Resource): + key: str + + +class WebhookValidationError(ValueError): + """Base exception for webhook validation errors.""" + + +class MissingWebhookHeaderError(WebhookValidationError): + """Exception raised when a required webhook header is missing.""" + + +class InvalidSecretKeyError(WebhookValidationError): + """Exception raised when the secret key format is invalid.""" + + +class MissingWebhookBodyError(WebhookValidationError): + """Exception raised when the webhook body is missing.""" + + +class InvalidTimestampError(WebhookValidationError): + """Exception raised when the webhook timestamp is invalid or outside the tolerance.""" + + +class InvalidSignatureError(WebhookValidationError): + """Exception raised when the webhook signature is invalid.""" + + +class Webhooks(Namespace): + @property + def default(self) -> "Webhooks.Default": + """ + Namespace for operations related to the default webhook. + """ + + return self.Default(self._client) + + class Default(Namespace): + def secret(self) -> WebhookSigningSecret: + """ + Get the default webhook signing secret. + + Returns: + WebhookSigningSecret: The default webhook signing secret. + """ + + resp = self._client._request("GET", "/v1/webhooks/default/secret") + return WebhookSigningSecret(**resp.json()) + + async def async_secret(self) -> WebhookSigningSecret: + """ + Get the default webhook signing secret. + + Returns: + WebhookSigningSecret: The default webhook signing secret. + """ + + resp = await self._client._async_request( + "GET", "/v1/webhooks/default/secret" + ) + return WebhookSigningSecret(**resp.json()) + + @overload + @staticmethod + def validate( + request: "httpx.Request", + secret: WebhookSigningSecret, + tolerance: Optional[int] = None, + ) -> bool: ... + + @overload + @staticmethod + def validate( + headers: Dict[str, str], + body: str, + secret: WebhookSigningSecret, + tolerance: Optional[int] = None, + ) -> bool: ... + + @staticmethod + def validate( # type: ignore + request: Optional["httpx.Request"] = None, + headers: Optional[Dict[str, str]] = None, + body: Optional[str] = None, + secret: Optional[WebhookSigningSecret] = None, + tolerance: Optional[int] = None, + ) -> None: + """ + Validate the signature from an incoming webhook request using the provided secret. + + Args: + request (httpx.Request): The request object. + headers (Dict[str, str]): The request headers. + body (str): The request body. + secret (WebhookSigningSecret): The webhook signing secret. + tolerance (Optional[int]): Maximum allowed time difference (in seconds) between the current time and the webhook timestamp. + + Returns: + None: If the request is valid. + + Raises: + MissingWebhookHeaderError: If required webhook headers are missing. + InvalidSecretKeyError: If the secret key format is invalid. + MissingWebhookBodyError: If the webhook body is missing. + InvalidTimestampError: If the webhook timestamp is invalid or outside the tolerance. + InvalidSignatureError: If the webhook signature is invalid. + """ + + if not secret: + raise ValueError("Missing webhook signing secret") + + if request and any([headers, body]): + raise ValueError("Only one of request or headers/body can be provided") + + if request and request.headers: + webhook_id = request.headers.get("webhook-id") + timestamp = request.headers.get("webhook-timestamp") + signature = request.headers.get("webhook-signature") + body = request.content.decode("utf-8") + else: + if not headers: + raise MissingWebhookHeaderError("Missing webhook headers") + + # Convert headers to case-insensitive dictionary + headers = {k.lower(): v for k, v in headers.items()} + + webhook_id = headers.get("webhook-id") + timestamp = headers.get("webhook-timestamp") + signature = headers.get("webhook-signature") + + if not webhook_id: + raise MissingWebhookHeaderError("Missing webhook id") + if not timestamp: + raise MissingWebhookHeaderError("Missing webhook timestamp") + if not signature: + raise MissingWebhookHeaderError("Missing webhook signature") + if not body: + raise MissingWebhookBodyError("Missing webhook body") + + if tolerance is not None: + import time # pylint: disable=import-outside-toplevel + + current_time = int(time.time()) + webhook_time = int(timestamp) + time_difference = abs(current_time - webhook_time) + if time_difference > tolerance: + raise InvalidTimestampError( + f"Webhook timestamp is outside the allowed tolerance of {tolerance} seconds" + ) + + signed_content = f"{webhook_id}.{timestamp}.{body}" + + key_parts = secret.key.split("_") + if len(key_parts) != 2: + raise InvalidSecretKeyError(f"Invalid secret key format: {secret.key}") + + secret_bytes = base64.b64decode(key_parts[1]) + + h = hmac.new(secret_bytes, signed_content.encode(), sha256) + computed_signature = h.digest() + + valid = False + for sig in signature.split(): + sig_parts = sig.split(",") + if len(sig_parts) < 2: + raise InvalidSignatureError(f"Invalid signature format: {sig}") + + sig_bytes = base64.b64decode(sig_parts[1]) + + if hmac.compare_digest(sig_bytes, computed_signature): + valid = True + break + + if not valid: + raise InvalidSignatureError("Webhook signature is invalid") diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py new file mode 100644 index 00000000..2e20ee25 --- /dev/null +++ b/tests/test_webhooks.py @@ -0,0 +1,162 @@ +import time + +import httpx +import pytest +import respx +from httpx import Request + +from replicate.client import Client +from replicate.webhook import ( + InvalidSecretKeyError, + InvalidSignatureError, + InvalidTimestampError, + MissingWebhookBodyError, + MissingWebhookHeaderError, + Webhooks, + WebhookSigningSecret, +) + + +@pytest.fixture +def webhook_signing_secret(): + # This is a test secret and should not be used in production + return WebhookSigningSecret(key="whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +@respx.mock +async def test_get_webhook_secret(async_flag, webhook_signing_secret): + respx.get("https://api.replicate.com/v1/webhooks/default/secret").mock( + return_value=httpx.Response(200, json={"key": webhook_signing_secret.key}) + ) + + client = Client(api_token="test-token") + + if async_flag: + secret = await client.webhooks.default.async_secret() + else: + secret = client.webhooks.default.secret() + + assert isinstance(secret, WebhookSigningSecret) + assert secret.key == webhook_signing_secret.key + + body = '{"test": 2432232314}' + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + } + + request = Request( + method="POST", + url="http://test.host/webhook", + headers=headers, + content=body.encode(), + ) + + client.webhooks.validate(request=request, secret=secret) + + +def test_validate_webhook_invalid_signature(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + with pytest.raises(InvalidSignatureError, match="Webhook signature is invalid"): + Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret) + + +def test_validate_webhook_missing_webhook_id(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + } + body = '{"test": 2432232314}' + + with pytest.raises(MissingWebhookHeaderError, match="Missing webhook id"): + Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret) + + +def test_validate_webhook_invalid_secret(): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + with pytest.raises(InvalidSecretKeyError, match="Invalid secret key format"): + Webhooks.validate( + headers=headers, + body=body, + secret=WebhookSigningSecret(key="invalid_secret_format"), + ) + + +def test_validate_webhook_missing_headers(webhook_signing_secret): + headers = None + body = '{"test": 2432232314}' + + with pytest.raises(MissingWebhookHeaderError, match="Missing webhook headers"): + Webhooks.validate( + headers=headers, # type: ignore + body=body, + secret=webhook_signing_secret, + ) + + +def test_validate_webhook_missing_body(webhook_signing_secret): + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + } + body = None + + with pytest.raises(MissingWebhookBodyError, match="Missing webhook body"): + Webhooks.validate( + headers=headers, + body=body, # type: ignore + secret=webhook_signing_secret, + ) + + +@pytest.mark.parametrize( + "timestamp_offset, timestamp_invalid", + [ + (-3601, True), # 1 hour and 1 second in the past + (3601, True), # 1 hour and 1 second in the future + (-3599, False), # 59 minutes and 59 seconds in the past (valid) + (3599, False), # 59 minutes and 59 seconds in the future (valid) + ], +) +def test_validate_webhook_timestamp( + webhook_signing_secret, timestamp_offset, timestamp_invalid +): + current_time = int(time.time()) + timestamp = str(current_time + timestamp_offset) + + headers = { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": timestamp, + "Webhook-Signature": "v1,invalid_signature", + } + body = '{"test": 2432232314}' + + with pytest.raises( + (InvalidTimestampError if timestamp_invalid else InvalidSignatureError), + ): + Webhooks.validate( + headers=headers, + body=body, + secret=webhook_signing_secret, + tolerance=3600, + )