|
| 1 | +import base64 |
| 2 | +import hmac |
| 3 | +from hashlib import sha256 |
| 4 | +from typing import ( |
| 5 | + TYPE_CHECKING, |
| 6 | + Dict, |
| 7 | + Optional, |
| 8 | + overload, |
| 9 | +) |
| 10 | + |
| 11 | +from replicate.resource import Namespace, Resource |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + import httpx |
| 15 | + |
| 16 | + |
| 17 | +class WebhookSigningSecret(Resource): |
| 18 | + key: str |
| 19 | + |
| 20 | + |
| 21 | +class WebhookValidationError(ValueError): |
| 22 | + """Base exception for webhook validation errors.""" |
| 23 | + |
| 24 | + |
| 25 | +class MissingWebhookHeaderError(WebhookValidationError): |
| 26 | + """Exception raised when a required webhook header is missing.""" |
| 27 | + |
| 28 | + |
| 29 | +class InvalidSecretKeyError(WebhookValidationError): |
| 30 | + """Exception raised when the secret key format is invalid.""" |
| 31 | + |
| 32 | + |
| 33 | +class MissingWebhookBodyError(WebhookValidationError): |
| 34 | + """Exception raised when the webhook body is missing.""" |
| 35 | + |
| 36 | + |
| 37 | +class InvalidTimestampError(WebhookValidationError): |
| 38 | + """Exception raised when the webhook timestamp is invalid or outside the tolerance.""" |
| 39 | + |
| 40 | + |
| 41 | +class InvalidSignatureError(WebhookValidationError): |
| 42 | + """Exception raised when the webhook signature is invalid.""" |
| 43 | + |
| 44 | + |
| 45 | +class Webhooks(Namespace): |
| 46 | + @property |
| 47 | + def default(self) -> "Webhooks.Default": |
| 48 | + """ |
| 49 | + Namespace for operations related to the default webhook. |
| 50 | + """ |
| 51 | + |
| 52 | + return self.Default(self._client) |
| 53 | + |
| 54 | + class Default(Namespace): |
| 55 | + def secret(self) -> WebhookSigningSecret: |
| 56 | + """ |
| 57 | + Get the default webhook signing secret. |
| 58 | +
|
| 59 | + Returns: |
| 60 | + WebhookSigningSecret: The default webhook signing secret. |
| 61 | + """ |
| 62 | + |
| 63 | + resp = self._client._request("GET", "/v1/webhooks/default/secret") |
| 64 | + return WebhookSigningSecret(**resp.json()) |
| 65 | + |
| 66 | + async def async_secret(self) -> WebhookSigningSecret: |
| 67 | + """ |
| 68 | + Get the default webhook signing secret. |
| 69 | +
|
| 70 | + Returns: |
| 71 | + WebhookSigningSecret: The default webhook signing secret. |
| 72 | + """ |
| 73 | + |
| 74 | + resp = await self._client._async_request( |
| 75 | + "GET", "/v1/webhooks/default/secret" |
| 76 | + ) |
| 77 | + return WebhookSigningSecret(**resp.json()) |
| 78 | + |
| 79 | + @overload |
| 80 | + @staticmethod |
| 81 | + def validate( |
| 82 | + request: "httpx.Request", |
| 83 | + secret: WebhookSigningSecret, |
| 84 | + tolerance: Optional[int] = None, |
| 85 | + ) -> bool: ... |
| 86 | + |
| 87 | + @overload |
| 88 | + @staticmethod |
| 89 | + def validate( |
| 90 | + headers: Dict[str, str], |
| 91 | + body: str, |
| 92 | + secret: WebhookSigningSecret, |
| 93 | + tolerance: Optional[int] = None, |
| 94 | + ) -> bool: ... |
| 95 | + |
| 96 | + @staticmethod |
| 97 | + def validate( # type: ignore |
| 98 | + request: Optional["httpx.Request"] = None, |
| 99 | + headers: Optional[Dict[str, str]] = None, |
| 100 | + body: Optional[str] = None, |
| 101 | + secret: Optional[WebhookSigningSecret] = None, |
| 102 | + tolerance: Optional[int] = None, |
| 103 | + ) -> None: |
| 104 | + """ |
| 105 | + Validate the signature from an incoming webhook request using the provided secret. |
| 106 | +
|
| 107 | + Args: |
| 108 | + request (httpx.Request): The request object. |
| 109 | + headers (Dict[str, str]): The request headers. |
| 110 | + body (str): The request body. |
| 111 | + secret (WebhookSigningSecret): The webhook signing secret. |
| 112 | + tolerance (Optional[int]): Maximum allowed time difference (in seconds) between the current time and the webhook timestamp. |
| 113 | +
|
| 114 | + Returns: |
| 115 | + None: If the request is valid. |
| 116 | +
|
| 117 | + Raises: |
| 118 | + MissingWebhookHeaderError: If required webhook headers are missing. |
| 119 | + InvalidSecretKeyError: If the secret key format is invalid. |
| 120 | + MissingWebhookBodyError: If the webhook body is missing. |
| 121 | + InvalidTimestampError: If the webhook timestamp is invalid or outside the tolerance. |
| 122 | + InvalidSignatureError: If the webhook signature is invalid. |
| 123 | + """ |
| 124 | + |
| 125 | + if not secret: |
| 126 | + raise ValueError("Missing webhook signing secret") |
| 127 | + |
| 128 | + if request and any([headers, body]): |
| 129 | + raise ValueError("Only one of request or headers/body can be provided") |
| 130 | + |
| 131 | + if request and request.headers: |
| 132 | + webhook_id = request.headers.get("webhook-id") |
| 133 | + timestamp = request.headers.get("webhook-timestamp") |
| 134 | + signature = request.headers.get("webhook-signature") |
| 135 | + body = request.content.decode("utf-8") |
| 136 | + else: |
| 137 | + if not headers: |
| 138 | + raise MissingWebhookHeaderError("Missing webhook headers") |
| 139 | + |
| 140 | + # Convert headers to case-insensitive dictionary |
| 141 | + headers = {k.lower(): v for k, v in headers.items()} |
| 142 | + |
| 143 | + webhook_id = headers.get("webhook-id") |
| 144 | + timestamp = headers.get("webhook-timestamp") |
| 145 | + signature = headers.get("webhook-signature") |
| 146 | + |
| 147 | + if not webhook_id: |
| 148 | + raise MissingWebhookHeaderError("Missing webhook id") |
| 149 | + if not timestamp: |
| 150 | + raise MissingWebhookHeaderError("Missing webhook timestamp") |
| 151 | + if not signature: |
| 152 | + raise MissingWebhookHeaderError("Missing webhook signature") |
| 153 | + if not body: |
| 154 | + raise MissingWebhookBodyError("Missing webhook body") |
| 155 | + |
| 156 | + if tolerance is not None: |
| 157 | + import time # pylint: disable=import-outside-toplevel |
| 158 | + |
| 159 | + current_time = int(time.time()) |
| 160 | + webhook_time = int(timestamp) |
| 161 | + time_difference = abs(current_time - webhook_time) |
| 162 | + if time_difference > tolerance: |
| 163 | + raise InvalidTimestampError( |
| 164 | + f"Webhook timestamp is outside the allowed tolerance of {tolerance} seconds" |
| 165 | + ) |
| 166 | + |
| 167 | + signed_content = f"{webhook_id}.{timestamp}.{body}" |
| 168 | + |
| 169 | + key_parts = secret.key.split("_") |
| 170 | + if len(key_parts) != 2: |
| 171 | + raise InvalidSecretKeyError(f"Invalid secret key format: {secret.key}") |
| 172 | + |
| 173 | + secret_bytes = base64.b64decode(key_parts[1]) |
| 174 | + |
| 175 | + h = hmac.new(secret_bytes, signed_content.encode(), sha256) |
| 176 | + computed_signature = h.digest() |
| 177 | + |
| 178 | + valid = False |
| 179 | + for sig in signature.split(): |
| 180 | + sig_parts = sig.split(",") |
| 181 | + if len(sig_parts) < 2: |
| 182 | + raise InvalidSignatureError(f"Invalid signature format: {sig}") |
| 183 | + |
| 184 | + sig_bytes = base64.b64decode(sig_parts[1]) |
| 185 | + |
| 186 | + if hmac.compare_digest(sig_bytes, computed_signature): |
| 187 | + valid = True |
| 188 | + break |
| 189 | + |
| 190 | + if not valid: |
| 191 | + raise InvalidSignatureError("Webhook signature is invalid") |
0 commit comments