Skip to content

Add support for validating webhooks #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from replicate.run import async_run, run
from replicate.stream import async_stream, stream
from replicate.training import Trainings
from replicate.webhook import Webhooks

if TYPE_CHECKING:
from replicate.stream import ServerSentEvent
Expand Down Expand Up @@ -144,6 +145,13 @@ def trainings(self) -> Trainings:
"""
return Trainings(client=self)

@property
def webhooks(self) -> Webhooks:
"""
Namespace for operations related to webhooks.
"""
return Webhooks(client=self)

def run(
self,
ref: str,
Expand Down
191 changes: 191 additions & 0 deletions replicate/webhook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import base64
import hmac
from hashlib import sha256
from typing import (
TYPE_CHECKING,
Dict,
Optional,
overload,
)

from replicate.resource import Namespace, Resource

if TYPE_CHECKING:
import httpx


class WebhookSigningSecret(Resource):
key: str


class WebhookValidationError(ValueError):
"""Base exception for webhook validation errors."""


class MissingWebhookHeaderError(WebhookValidationError):
"""Exception raised when a required webhook header is missing."""


class InvalidSecretKeyError(WebhookValidationError):
"""Exception raised when the secret key format is invalid."""


class MissingWebhookBodyError(WebhookValidationError):
"""Exception raised when the webhook body is missing."""


class InvalidTimestampError(WebhookValidationError):
"""Exception raised when the webhook timestamp is invalid or outside the tolerance."""


class InvalidSignatureError(WebhookValidationError):
"""Exception raised when the webhook signature is invalid."""


class Webhooks(Namespace):
@property
def default(self) -> "Webhooks.Default":
"""
Namespace for operations related to the default webhook.
"""

return self.Default(self._client)

class Default(Namespace):
def secret(self) -> WebhookSigningSecret:
"""
Get the default webhook signing secret.

Returns:
WebhookSigningSecret: The default webhook signing secret.
"""

resp = self._client._request("GET", "/v1/webhooks/default/secret")
return WebhookSigningSecret(**resp.json())

async def async_secret(self) -> WebhookSigningSecret:
"""
Get the default webhook signing secret.

Returns:
WebhookSigningSecret: The default webhook signing secret.
"""

resp = await self._client._async_request(
"GET", "/v1/webhooks/default/secret"
)
return WebhookSigningSecret(**resp.json())

@overload
@staticmethod
def validate(
request: "httpx.Request",
secret: WebhookSigningSecret,
tolerance: Optional[int] = None,
) -> bool: ...

@overload
@staticmethod
def validate(
headers: Dict[str, str],
body: str,
secret: WebhookSigningSecret,
tolerance: Optional[int] = None,
) -> bool: ...

@staticmethod
def validate( # type: ignore
request: Optional["httpx.Request"] = None,
headers: Optional[Dict[str, str]] = None,
body: Optional[str] = None,
secret: Optional[WebhookSigningSecret] = None,
tolerance: Optional[int] = None,
) -> None:
"""
Validate the signature from an incoming webhook request using the provided secret.

Args:
request (httpx.Request): The request object.
headers (Dict[str, str]): The request headers.
body (str): The request body.
secret (WebhookSigningSecret): The webhook signing secret.
tolerance (Optional[int]): Maximum allowed time difference (in seconds) between the current time and the webhook timestamp.

Returns:
None: If the request is valid.

Raises:
MissingWebhookHeaderError: If required webhook headers are missing.
InvalidSecretKeyError: If the secret key format is invalid.
MissingWebhookBodyError: If the webhook body is missing.
InvalidTimestampError: If the webhook timestamp is invalid or outside the tolerance.
InvalidSignatureError: If the webhook signature is invalid.
"""

if not secret:
raise ValueError("Missing webhook signing secret")

if request and any([headers, body]):
raise ValueError("Only one of request or headers/body can be provided")

if request and request.headers:
webhook_id = request.headers.get("webhook-id")
timestamp = request.headers.get("webhook-timestamp")
signature = request.headers.get("webhook-signature")
body = request.content.decode("utf-8")
else:
if not headers:
raise MissingWebhookHeaderError("Missing webhook headers")

# Convert headers to case-insensitive dictionary
headers = {k.lower(): v for k, v in headers.items()}

webhook_id = headers.get("webhook-id")
timestamp = headers.get("webhook-timestamp")
signature = headers.get("webhook-signature")

if not webhook_id:
raise MissingWebhookHeaderError("Missing webhook id")
if not timestamp:
raise MissingWebhookHeaderError("Missing webhook timestamp")
if not signature:
raise MissingWebhookHeaderError("Missing webhook signature")
if not body:
raise MissingWebhookBodyError("Missing webhook body")

if tolerance is not None:
import time # pylint: disable=import-outside-toplevel

current_time = int(time.time())
webhook_time = int(timestamp)
time_difference = abs(current_time - webhook_time)
if time_difference > tolerance:
raise InvalidTimestampError(
f"Webhook timestamp is outside the allowed tolerance of {tolerance} seconds"
)

signed_content = f"{webhook_id}.{timestamp}.{body}"

key_parts = secret.key.split("_")
if len(key_parts) != 2:
raise InvalidSecretKeyError(f"Invalid secret key format: {secret.key}")

secret_bytes = base64.b64decode(key_parts[1])

h = hmac.new(secret_bytes, signed_content.encode(), sha256)
computed_signature = h.digest()

valid = False
for sig in signature.split():
sig_parts = sig.split(",")
if len(sig_parts) < 2:
raise InvalidSignatureError(f"Invalid signature format: {sig}")

sig_bytes = base64.b64decode(sig_parts[1])

if hmac.compare_digest(sig_bytes, computed_signature):
valid = True
break

if not valid:
raise InvalidSignatureError("Webhook signature is invalid")
162 changes: 162 additions & 0 deletions tests/test_webhooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import time

import httpx
import pytest
import respx
from httpx import Request

from replicate.client import Client
from replicate.webhook import (
InvalidSecretKeyError,
InvalidSignatureError,
InvalidTimestampError,
MissingWebhookBodyError,
MissingWebhookHeaderError,
Webhooks,
WebhookSigningSecret,
)


@pytest.fixture
def webhook_signing_secret():
# This is a test secret and should not be used in production
return WebhookSigningSecret(key="whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw")


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
@respx.mock
async def test_get_webhook_secret(async_flag, webhook_signing_secret):
respx.get("https://api.replicate.com/v1/webhooks/default/secret").mock(
return_value=httpx.Response(200, json={"key": webhook_signing_secret.key})
)

client = Client(api_token="test-token")

if async_flag:
secret = await client.webhooks.default.async_secret()
else:
secret = client.webhooks.default.secret()

assert isinstance(secret, WebhookSigningSecret)
assert secret.key == webhook_signing_secret.key

body = '{"test": 2432232314}'
headers = {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
}

request = Request(
method="POST",
url="http://test.host/webhook",
headers=headers,
content=body.encode(),
)

client.webhooks.validate(request=request, secret=secret)


def test_validate_webhook_invalid_signature(webhook_signing_secret):
headers = {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature": "v1,invalid_signature",
}
body = '{"test": 2432232314}'

with pytest.raises(InvalidSignatureError, match="Webhook signature is invalid"):
Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret)


def test_validate_webhook_missing_webhook_id(webhook_signing_secret):
headers = {
"Content-Type": "application/json",
}
body = '{"test": 2432232314}'

with pytest.raises(MissingWebhookHeaderError, match="Missing webhook id"):
Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret)


def test_validate_webhook_invalid_secret():
headers = {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature": "v1,invalid_signature",
}
body = '{"test": 2432232314}'

with pytest.raises(InvalidSecretKeyError, match="Invalid secret key format"):
Webhooks.validate(
headers=headers,
body=body,
secret=WebhookSigningSecret(key="invalid_secret_format"),
)


def test_validate_webhook_missing_headers(webhook_signing_secret):
headers = None
body = '{"test": 2432232314}'

with pytest.raises(MissingWebhookHeaderError, match="Missing webhook headers"):
Webhooks.validate(
headers=headers, # type: ignore
body=body,
secret=webhook_signing_secret,
)


def test_validate_webhook_missing_body(webhook_signing_secret):
headers = {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
}
body = None

with pytest.raises(MissingWebhookBodyError, match="Missing webhook body"):
Webhooks.validate(
headers=headers,
body=body, # type: ignore
secret=webhook_signing_secret,
)


@pytest.mark.parametrize(
"timestamp_offset, timestamp_invalid",
[
(-3601, True), # 1 hour and 1 second in the past
(3601, True), # 1 hour and 1 second in the future
(-3599, False), # 59 minutes and 59 seconds in the past (valid)
(3599, False), # 59 minutes and 59 seconds in the future (valid)
],
)
def test_validate_webhook_timestamp(
webhook_signing_secret, timestamp_offset, timestamp_invalid
):
current_time = int(time.time())
timestamp = str(current_time + timestamp_offset)

headers = {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": timestamp,
"Webhook-Signature": "v1,invalid_signature",
}
body = '{"test": 2432232314}'

with pytest.raises(
(InvalidTimestampError if timestamp_invalid else InvalidSignatureError),
):
Webhooks.validate(
headers=headers,
body=body,
secret=webhook_signing_secret,
tolerance=3600,
)