Skip to content

Commit b8196f5

Browse files
authored
Add support for validating webhooks (#321)
Resolves #312 --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent ec459ed commit b8196f5

File tree

3 files changed

+361
-0
lines changed

3 files changed

+361
-0
lines changed

replicate/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from replicate.run import async_run, run
3131
from replicate.stream import async_stream, stream
3232
from replicate.training import Trainings
33+
from replicate.webhook import Webhooks
3334

3435
if TYPE_CHECKING:
3536
from replicate.stream import ServerSentEvent
@@ -144,6 +145,13 @@ def trainings(self) -> Trainings:
144145
"""
145146
return Trainings(client=self)
146147

148+
@property
149+
def webhooks(self) -> Webhooks:
150+
"""
151+
Namespace for operations related to webhooks.
152+
"""
153+
return Webhooks(client=self)
154+
147155
def run(
148156
self,
149157
ref: str,

replicate/webhook.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import base64
2+
import hmac
3+
from hashlib import sha256
4+
from typing import (
5+
TYPE_CHECKING,
6+
Dict,
7+
Optional,
8+
overload,
9+
)
10+
11+
from replicate.resource import Namespace, Resource
12+
13+
if TYPE_CHECKING:
14+
import httpx
15+
16+
17+
class WebhookSigningSecret(Resource):
18+
key: str
19+
20+
21+
class WebhookValidationError(ValueError):
22+
"""Base exception for webhook validation errors."""
23+
24+
25+
class MissingWebhookHeaderError(WebhookValidationError):
26+
"""Exception raised when a required webhook header is missing."""
27+
28+
29+
class InvalidSecretKeyError(WebhookValidationError):
30+
"""Exception raised when the secret key format is invalid."""
31+
32+
33+
class MissingWebhookBodyError(WebhookValidationError):
34+
"""Exception raised when the webhook body is missing."""
35+
36+
37+
class InvalidTimestampError(WebhookValidationError):
38+
"""Exception raised when the webhook timestamp is invalid or outside the tolerance."""
39+
40+
41+
class InvalidSignatureError(WebhookValidationError):
42+
"""Exception raised when the webhook signature is invalid."""
43+
44+
45+
class Webhooks(Namespace):
46+
@property
47+
def default(self) -> "Webhooks.Default":
48+
"""
49+
Namespace for operations related to the default webhook.
50+
"""
51+
52+
return self.Default(self._client)
53+
54+
class Default(Namespace):
55+
def secret(self) -> WebhookSigningSecret:
56+
"""
57+
Get the default webhook signing secret.
58+
59+
Returns:
60+
WebhookSigningSecret: The default webhook signing secret.
61+
"""
62+
63+
resp = self._client._request("GET", "/v1/webhooks/default/secret")
64+
return WebhookSigningSecret(**resp.json())
65+
66+
async def async_secret(self) -> WebhookSigningSecret:
67+
"""
68+
Get the default webhook signing secret.
69+
70+
Returns:
71+
WebhookSigningSecret: The default webhook signing secret.
72+
"""
73+
74+
resp = await self._client._async_request(
75+
"GET", "/v1/webhooks/default/secret"
76+
)
77+
return WebhookSigningSecret(**resp.json())
78+
79+
@overload
80+
@staticmethod
81+
def validate(
82+
request: "httpx.Request",
83+
secret: WebhookSigningSecret,
84+
tolerance: Optional[int] = None,
85+
) -> bool: ...
86+
87+
@overload
88+
@staticmethod
89+
def validate(
90+
headers: Dict[str, str],
91+
body: str,
92+
secret: WebhookSigningSecret,
93+
tolerance: Optional[int] = None,
94+
) -> bool: ...
95+
96+
@staticmethod
97+
def validate( # type: ignore
98+
request: Optional["httpx.Request"] = None,
99+
headers: Optional[Dict[str, str]] = None,
100+
body: Optional[str] = None,
101+
secret: Optional[WebhookSigningSecret] = None,
102+
tolerance: Optional[int] = None,
103+
) -> None:
104+
"""
105+
Validate the signature from an incoming webhook request using the provided secret.
106+
107+
Args:
108+
request (httpx.Request): The request object.
109+
headers (Dict[str, str]): The request headers.
110+
body (str): The request body.
111+
secret (WebhookSigningSecret): The webhook signing secret.
112+
tolerance (Optional[int]): Maximum allowed time difference (in seconds) between the current time and the webhook timestamp.
113+
114+
Returns:
115+
None: If the request is valid.
116+
117+
Raises:
118+
MissingWebhookHeaderError: If required webhook headers are missing.
119+
InvalidSecretKeyError: If the secret key format is invalid.
120+
MissingWebhookBodyError: If the webhook body is missing.
121+
InvalidTimestampError: If the webhook timestamp is invalid or outside the tolerance.
122+
InvalidSignatureError: If the webhook signature is invalid.
123+
"""
124+
125+
if not secret:
126+
raise ValueError("Missing webhook signing secret")
127+
128+
if request and any([headers, body]):
129+
raise ValueError("Only one of request or headers/body can be provided")
130+
131+
if request and request.headers:
132+
webhook_id = request.headers.get("webhook-id")
133+
timestamp = request.headers.get("webhook-timestamp")
134+
signature = request.headers.get("webhook-signature")
135+
body = request.content.decode("utf-8")
136+
else:
137+
if not headers:
138+
raise MissingWebhookHeaderError("Missing webhook headers")
139+
140+
# Convert headers to case-insensitive dictionary
141+
headers = {k.lower(): v for k, v in headers.items()}
142+
143+
webhook_id = headers.get("webhook-id")
144+
timestamp = headers.get("webhook-timestamp")
145+
signature = headers.get("webhook-signature")
146+
147+
if not webhook_id:
148+
raise MissingWebhookHeaderError("Missing webhook id")
149+
if not timestamp:
150+
raise MissingWebhookHeaderError("Missing webhook timestamp")
151+
if not signature:
152+
raise MissingWebhookHeaderError("Missing webhook signature")
153+
if not body:
154+
raise MissingWebhookBodyError("Missing webhook body")
155+
156+
if tolerance is not None:
157+
import time # pylint: disable=import-outside-toplevel
158+
159+
current_time = int(time.time())
160+
webhook_time = int(timestamp)
161+
time_difference = abs(current_time - webhook_time)
162+
if time_difference > tolerance:
163+
raise InvalidTimestampError(
164+
f"Webhook timestamp is outside the allowed tolerance of {tolerance} seconds"
165+
)
166+
167+
signed_content = f"{webhook_id}.{timestamp}.{body}"
168+
169+
key_parts = secret.key.split("_")
170+
if len(key_parts) != 2:
171+
raise InvalidSecretKeyError(f"Invalid secret key format: {secret.key}")
172+
173+
secret_bytes = base64.b64decode(key_parts[1])
174+
175+
h = hmac.new(secret_bytes, signed_content.encode(), sha256)
176+
computed_signature = h.digest()
177+
178+
valid = False
179+
for sig in signature.split():
180+
sig_parts = sig.split(",")
181+
if len(sig_parts) < 2:
182+
raise InvalidSignatureError(f"Invalid signature format: {sig}")
183+
184+
sig_bytes = base64.b64decode(sig_parts[1])
185+
186+
if hmac.compare_digest(sig_bytes, computed_signature):
187+
valid = True
188+
break
189+
190+
if not valid:
191+
raise InvalidSignatureError("Webhook signature is invalid")

tests/test_webhooks.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import time
2+
3+
import httpx
4+
import pytest
5+
import respx
6+
from httpx import Request
7+
8+
from replicate.client import Client
9+
from replicate.webhook import (
10+
InvalidSecretKeyError,
11+
InvalidSignatureError,
12+
InvalidTimestampError,
13+
MissingWebhookBodyError,
14+
MissingWebhookHeaderError,
15+
Webhooks,
16+
WebhookSigningSecret,
17+
)
18+
19+
20+
@pytest.fixture
21+
def webhook_signing_secret():
22+
# This is a test secret and should not be used in production
23+
return WebhookSigningSecret(key="whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw")
24+
25+
26+
@pytest.mark.asyncio
27+
@pytest.mark.parametrize("async_flag", [True, False])
28+
@respx.mock
29+
async def test_get_webhook_secret(async_flag, webhook_signing_secret):
30+
respx.get("https://api.replicate.com/v1/webhooks/default/secret").mock(
31+
return_value=httpx.Response(200, json={"key": webhook_signing_secret.key})
32+
)
33+
34+
client = Client(api_token="test-token")
35+
36+
if async_flag:
37+
secret = await client.webhooks.default.async_secret()
38+
else:
39+
secret = client.webhooks.default.secret()
40+
41+
assert isinstance(secret, WebhookSigningSecret)
42+
assert secret.key == webhook_signing_secret.key
43+
44+
body = '{"test": 2432232314}'
45+
headers = {
46+
"Content-Type": "application/json",
47+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
48+
"Webhook-Timestamp": "1614265330",
49+
"Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
50+
}
51+
52+
request = Request(
53+
method="POST",
54+
url="http://test.host/webhook",
55+
headers=headers,
56+
content=body.encode(),
57+
)
58+
59+
client.webhooks.validate(request=request, secret=secret)
60+
61+
62+
def test_validate_webhook_invalid_signature(webhook_signing_secret):
63+
headers = {
64+
"Content-Type": "application/json",
65+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
66+
"Webhook-Timestamp": "1614265330",
67+
"Webhook-Signature": "v1,invalid_signature",
68+
}
69+
body = '{"test": 2432232314}'
70+
71+
with pytest.raises(InvalidSignatureError, match="Webhook signature is invalid"):
72+
Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret)
73+
74+
75+
def test_validate_webhook_missing_webhook_id(webhook_signing_secret):
76+
headers = {
77+
"Content-Type": "application/json",
78+
}
79+
body = '{"test": 2432232314}'
80+
81+
with pytest.raises(MissingWebhookHeaderError, match="Missing webhook id"):
82+
Webhooks.validate(headers=headers, body=body, secret=webhook_signing_secret)
83+
84+
85+
def test_validate_webhook_invalid_secret():
86+
headers = {
87+
"Content-Type": "application/json",
88+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
89+
"Webhook-Timestamp": "1614265330",
90+
"Webhook-Signature": "v1,invalid_signature",
91+
}
92+
body = '{"test": 2432232314}'
93+
94+
with pytest.raises(InvalidSecretKeyError, match="Invalid secret key format"):
95+
Webhooks.validate(
96+
headers=headers,
97+
body=body,
98+
secret=WebhookSigningSecret(key="invalid_secret_format"),
99+
)
100+
101+
102+
def test_validate_webhook_missing_headers(webhook_signing_secret):
103+
headers = None
104+
body = '{"test": 2432232314}'
105+
106+
with pytest.raises(MissingWebhookHeaderError, match="Missing webhook headers"):
107+
Webhooks.validate(
108+
headers=headers, # type: ignore
109+
body=body,
110+
secret=webhook_signing_secret,
111+
)
112+
113+
114+
def test_validate_webhook_missing_body(webhook_signing_secret):
115+
headers = {
116+
"Content-Type": "application/json",
117+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
118+
"Webhook-Timestamp": "1614265330",
119+
"Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
120+
}
121+
body = None
122+
123+
with pytest.raises(MissingWebhookBodyError, match="Missing webhook body"):
124+
Webhooks.validate(
125+
headers=headers,
126+
body=body, # type: ignore
127+
secret=webhook_signing_secret,
128+
)
129+
130+
131+
@pytest.mark.parametrize(
132+
"timestamp_offset, timestamp_invalid",
133+
[
134+
(-3601, True), # 1 hour and 1 second in the past
135+
(3601, True), # 1 hour and 1 second in the future
136+
(-3599, False), # 59 minutes and 59 seconds in the past (valid)
137+
(3599, False), # 59 minutes and 59 seconds in the future (valid)
138+
],
139+
)
140+
def test_validate_webhook_timestamp(
141+
webhook_signing_secret, timestamp_offset, timestamp_invalid
142+
):
143+
current_time = int(time.time())
144+
timestamp = str(current_time + timestamp_offset)
145+
146+
headers = {
147+
"Content-Type": "application/json",
148+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
149+
"Webhook-Timestamp": timestamp,
150+
"Webhook-Signature": "v1,invalid_signature",
151+
}
152+
body = '{"test": 2432232314}'
153+
154+
with pytest.raises(
155+
(InvalidTimestampError if timestamp_invalid else InvalidSignatureError),
156+
):
157+
Webhooks.validate(
158+
headers=headers,
159+
body=body,
160+
secret=webhook_signing_secret,
161+
tolerance=3600,
162+
)

0 commit comments

Comments
 (0)