diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 48be870b..82d55909 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -414,6 +414,22 @@ describe("OAuth Authorization", () => { }); describe("exchangeAuthorization", () => { + const mockProvider: OAuthClientProvider = { + get redirectUrl() { return "http://localhost:3000/callback"; }, + get clientMetadata() { + return { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + }, + clientInformation: jest.fn(), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + const validTokens = { access_token: "access123", token_type: "Bearer", @@ -449,12 +465,11 @@ describe("OAuth Authorization", () => { }), expect.objectContaining({ method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, }) ); + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); const body = mockFetch.mock.calls[0][1].body as URLSearchParams; expect(body.get("grant_type")).toBe("authorization_code"); expect(body.get("code")).toBe("code123"); @@ -464,6 +479,48 @@ describe("OAuth Authorization", () => { expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); }); + it("exchanges code for tokens with auth", async () => { + mockProvider.authToTokenEndpoint = function(headers: Headers, params: URLSearchParams) { + headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); + params.set("example_param", "example_value") + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + redirectUri: "http://localhost:3000/callback", + }, mockProvider); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); + expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("authorization_code"); + expect(body.get("code")).toBe("code123"); + expect(body.get("code_verifier")).toBe("verifier123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); + expect(body.get("example_param")).toBe("example_value"); + expect(body.get("client_secret")).toBeUndefined; + }); + it("validates token response schema", async () => { mockFetch.mockResolvedValueOnce({ ok: true, @@ -502,6 +559,22 @@ describe("OAuth Authorization", () => { }); describe("refreshAuthorization", () => { + const mockProvider: OAuthClientProvider = { + get redirectUrl() { return "http://localhost:3000/callback"; }, + get clientMetadata() { + return { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + }, + clientInformation: jest.fn(), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + const validTokens = { access_token: "newaccess123", token_type: "Bearer", @@ -538,12 +611,11 @@ describe("OAuth Authorization", () => { }), expect.objectContaining({ method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, }) ); + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); const body = mockFetch.mock.calls[0][1].body as URLSearchParams; expect(body.get("grant_type")).toBe("refresh_token"); expect(body.get("refresh_token")).toBe("refresh123"); @@ -551,6 +623,44 @@ describe("OAuth Authorization", () => { expect(body.get("client_secret")).toBe("secret123"); }); + it("exchanges refresh token for new tokens with auth", async () => { + mockProvider.authToTokenEndpoint = function(headers: Headers, params: URLSearchParams) { + headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); + params.set("example_param", "example_value") + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokensWithNewRefreshToken, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + refreshToken: "refresh123", + }, mockProvider); + + expect(tokens).toEqual(validTokensWithNewRefreshToken); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); + expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("example_param")).toBe("example_value"); + expect(body.get("client_secret")).toBeUndefined; + }); + it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => { mockFetch.mockResolvedValueOnce({ ok: true, diff --git a/src/client/auth.ts b/src/client/auth.ts index 16f0a550..827f23f7 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -66,6 +66,8 @@ export interface OAuthClientProvider { * the authorization result. */ codeVerifier(): string | Promise; + + authToTokenEndpoint?(headers: Headers, params: URLSearchParams): void | Promise; } export type AuthResult = "AUTHORIZED" | "REDIRECT"; @@ -137,7 +139,7 @@ export async function auth( authorizationCode, codeVerifier, redirectUri: provider.redirectUrl, - }); + }, provider); await provider.saveTokens(tokens); return "AUTHORIZED"; @@ -153,7 +155,7 @@ export async function auth( metadata, clientInformation, refreshToken: tokens.refresh_token, - }); + }, provider); await provider.saveTokens(newTokens); return "AUTHORIZED"; @@ -372,6 +374,7 @@ export async function exchangeAuthorization( codeVerifier: string; redirectUri: string | URL; }, + provider?: OAuthClientProvider ): Promise { const grantType = "authorization_code"; @@ -392,6 +395,9 @@ export async function exchangeAuthorization( } // Exchange code for tokens + const headers = new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }); const params = new URLSearchParams({ grant_type: grantType, client_id: clientInformation.client_id, @@ -400,15 +406,15 @@ export async function exchangeAuthorization( redirect_uri: String(redirectUri), }); - if (clientInformation.client_secret) { + if (provider?.authToTokenEndpoint) { + provider.authToTokenEndpoint(headers, params); + } else if (clientInformation.client_secret) { params.set("client_secret", clientInformation.client_secret); } const response = await fetch(tokenUrl, { method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, + headers: headers, body: params, }); @@ -433,6 +439,7 @@ export async function refreshAuthorization( clientInformation: OAuthClientInformation; refreshToken: string; }, + provider?: OAuthClientProvider, ): Promise { const grantType = "refresh_token"; @@ -453,21 +460,24 @@ export async function refreshAuthorization( } // Exchange refresh token + const headers = new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }); const params = new URLSearchParams({ grant_type: grantType, client_id: clientInformation.client_id, refresh_token: refreshToken, }); - if (clientInformation.client_secret) { + if (provider?.authToTokenEndpoint) { + provider.authToTokenEndpoint(headers, params); + } else if (clientInformation.client_secret) { params.set("client_secret", clientInformation.client_secret); } const response = await fetch(tokenUrl, { method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, + headers: headers, body: params, }); if (!response.ok) {