Skip to content

Misc fixes in auth #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
58 changes: 58 additions & 0 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,64 @@ describe("OAuth Authorization", () => {
});
});

it("returns metadata when first fetch fails but second without MCP header succeeds", async () => {
// Set up a counter to control behavior
let callCount = 0;

// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;

if (callCount === 1) {
// First call with MCP header - fail with TypeError (simulating CORS error)
// We need to use TypeError specifically because that's what the implementation checks for
return Promise.reject(new TypeError("Network error"));
} else {
// Second call without header - succeed
return Promise.resolve({
ok: true,
status: 200,
json: async () => validMetadata
});
}
});

// Should succeed with the second call
const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toEqual(validMetadata);

// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);

// Verify first call had MCP header
expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version");
});

it("throws an error when all fetch attempts fail", async () => {
// Set up a counter to control behavior
let callCount = 0;

// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;

if (callCount === 1) {
// First call - fail with TypeError
return Promise.reject(new TypeError("First failure"));
} else {
// Second call - fail with different error
return Promise.reject(new Error("Second failure"));
}
});

// Should fail with the second error
await expect(discoverOAuthMetadata("https://auth.example.com"))
.rejects.toThrow("Second failure");

// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);
});

it("returns undefined when discovery endpoint returns 404", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
Expand Down
18 changes: 14 additions & 4 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,21 @@ export async function discoverOAuthMetadata(
opts?: { protocolVersion?: string },
): Promise<OAuthMetadata | undefined> {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
const response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
let response: Response;
try {
response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
}
});
} catch (error) {
// CORS errors come back as TypeError
if (error instanceof TypeError) {
response = await fetch(url);
} else {
throw error;
}
});
}

if (response.status === 404) {
return undefined;
Expand Down
31 changes: 31 additions & 0 deletions src/server/auth/handlers/register.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,37 @@ describe('Client Registration Handler', () => {

expect(response.status).toBe(201);
expect(response.body.client_secret).toBeUndefined();
expect(response.body.client_secret_expires_at).toBeUndefined();
});

it('sets client_secret_expires_at for public clients only', async () => {
// Test for public client (token_endpoint_auth_method not 'none')
const publicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'client_secret_basic'
};

const publicResponse = await supertest(app)
.post('/register')
.send(publicClientMetadata);

expect(publicResponse.status).toBe(201);
expect(publicResponse.body.client_secret).toBeDefined();
expect(publicResponse.body.client_secret_expires_at).toBeDefined();

// Test for non-public client (token_endpoint_auth_method is 'none')
const nonPublicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'none'
};

const nonPublicResponse = await supertest(app)
.post('/register')
.send(nonPublicClientMetadata);

expect(nonPublicResponse.status).toBe(201);
expect(nonPublicResponse.body.client_secret).toBeUndefined();
expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined();
});

it('sets expiry based on clientSecretExpirySeconds', async () => {
Expand Down
14 changes: 10 additions & 4 deletions src/server/auth/handlers/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,26 @@ export function clientRegistrationHandler({
}

const clientMetadata = parseResult.data;
const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none'

// Generate client credentials
const clientId = crypto.randomUUID();
const clientSecret = clientMetadata.token_endpoint_auth_method !== 'none'
? crypto.randomBytes(32).toString('hex')
: undefined;
const clientSecret = isPublicClient
? undefined
: crypto.randomBytes(32).toString('hex');
const clientIdIssuedAt = Math.floor(Date.now() / 1000);

// Calculate client secret expiry time
const clientsDoExpire = clientSecretExpirySeconds > 0
const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0
const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime

let clientInfo: OAuthClientInformationFull = {
...clientMetadata,
client_id: clientId,
client_secret: clientSecret,
client_id_issued_at: clientIdIssuedAt,
client_secret_expires_at: clientSecretExpirySeconds > 0 ? clientIdIssuedAt + clientSecretExpirySeconds : 0
client_secret_expires_at: clientSecretExpiresAt,
};

clientInfo = await clientsStore.registerClient!(clientInfo);
Expand Down
51 changes: 51 additions & 0 deletions src/server/auth/middleware/bearerAuth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,57 @@ describe("requireBearerAuth middleware", () => {
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should reject expired tokens", async () => {
const expiredAuthInfo: AuthInfo = {
token: "expired-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago
};
mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer expired-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token");
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" })
);
expect(nextFunction).not.toHaveBeenCalled();
});

it("should accept non-expired tokens", async () => {
const nonExpiredAuthInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour
};
mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer valid-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(nonExpiredAuthInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should require specific scopes when configured", async () => {
const authInfo: AuthInfo = {
Expand Down
5 changes: 5 additions & 0 deletions src/server/auth/middleware/bearerAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ export function requireBearerAuth({ provider, requiredScopes = [] }: BearerAuthM
}
}

// Check if the token is expired
if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) {
throw new InvalidTokenError("Token has expired");
}

req.auth = authInfo;
next();
} catch (error) {
Expand Down