Skip to content

Misc fixes in auth #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
30 changes: 30 additions & 0 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ describe("OAuth Authorization", () => {
});
});

it("returns metadata when first fetch fails but second without MCP header succeeds", async () => {
// First request with MCP header fails
mockFetch.mockRejectedValueOnce(new Error("Network error"));

// Second request without header succeeds
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validMetadata,
});

const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toEqual(validMetadata);

// Verify second call was made without header
expect(mockFetch).toHaveBeenCalledTimes(2);
const secondCallOptions = mockFetch.mock.calls[1][1];
expect(secondCallOptions).toBeUndefined(); // No options means no headers
});

it("returns undefined when all fetch attempts fail", async () => {
// Both requests fail
mockFetch.mockRejectedValueOnce(new Error("Network error"));
mockFetch.mockRejectedValueOnce(new Error("Network error"));

const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toBeUndefined();
expect(mockFetch).toHaveBeenCalledTimes(2);
});

it("returns undefined when discovery endpoint returns 404", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
Expand Down
17 changes: 13 additions & 4 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,20 @@ export async function discoverOAuthMetadata(
opts?: { protocolVersion?: string },
): Promise<OAuthMetadata | undefined> {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
const response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
let response: Response;
try {
response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
}
});
} catch {
try {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to conditionalize this based on what the error actually was?

response = await fetch(url);
} catch {
return undefined;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure you don't want to log/throw the exception?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be fine to let this error propagate if there's an actual network error (or otherwise), rather than a 404.

}
});
}

if (response.status === 404) {
return undefined;
Expand Down
31 changes: 31 additions & 0 deletions src/server/auth/handlers/register.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,37 @@ describe('Client Registration Handler', () => {

expect(response.status).toBe(201);
expect(response.body.client_secret).toBeUndefined();
expect(response.body.client_secret_expires_at).toBeUndefined();
});

it('sets client_secret_expires_at for public clients only', async () => {
// Test for public client (token_endpoint_auth_method not 'none')
const publicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'client_secret_basic'
};

const publicResponse = await supertest(app)
.post('/register')
.send(publicClientMetadata);

expect(publicResponse.status).toBe(201);
expect(publicResponse.body.client_secret).toBeDefined();
expect(publicResponse.body.client_secret_expires_at).toBeDefined();

// Test for non-public client (token_endpoint_auth_method is 'none')
const nonPublicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'none'
};

const nonPublicResponse = await supertest(app)
.post('/register')
.send(nonPublicClientMetadata);

expect(nonPublicResponse.status).toBe(201);
expect(nonPublicResponse.body.client_secret).toBeUndefined();
expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined();
});

it('sets expiry based on clientSecretExpirySeconds', async () => {
Expand Down
9 changes: 7 additions & 2 deletions src/server/auth/handlers/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ export function clientRegistrationHandler({
}

const clientMetadata = parseResult.data;
const isPublicClient = clientMetadata.token_endpoint_auth_method !== 'none'

// Generate client credentials
const clientId = crypto.randomUUID();
const clientSecret = clientMetadata.token_endpoint_auth_method !== 'none'
const clientSecret = isPublicClient
? crypto.randomBytes(32).toString('hex')
: undefined;
const clientIdIssuedAt = Math.floor(Date.now() / 1000);
Expand All @@ -88,7 +89,11 @@ export function clientRegistrationHandler({
client_id: clientId,
client_secret: clientSecret,
client_id_issued_at: clientIdIssuedAt,
client_secret_expires_at: clientSecretExpirySeconds > 0 ? clientIdIssuedAt + clientSecretExpirySeconds : 0
client_secret_expires_at: isPublicClient
? clientSecretExpirySeconds > 0
? clientIdIssuedAt + clientSecretExpirySeconds
: 0
: undefined,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nested ternaries is a bit much—can we pull some of this out into one or two intermediate variables?

};

clientInfo = await clientsStore.registerClient!(clientInfo);
Expand Down
51 changes: 51 additions & 0 deletions src/server/auth/middleware/bearerAuth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,57 @@
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should reject expired tokens", async () => {
const expiredAuthInfo: AuthInfo = {
token: "expired-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago
};
mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer expired-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token");
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" })
);
expect(nextFunction).not.toHaveBeenCalled();
});

it("should accept non-expired tokens", async () => {
const nonExpiredAuthInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour
};
mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer valid-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(nonExpiredAuthInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should require specific scopes when configured", async () => {
const authInfo: AuthInfo = {
Expand Down
5 changes: 5 additions & 0 deletions src/server/auth/middleware/bearerAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ export function requireBearerAuth({ provider, requiredScopes = [] }: BearerAuthM
}
}

// Check if the token is expired
if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) {
throw new InvalidTokenError("Token has expired");
}

req.auth = authInfo;
next();
} catch (error) {
Expand Down