diff --git a/.changeset/spicy-bags-sparkle.md b/.changeset/spicy-bags-sparkle.md new file mode 100644 index 00000000000..3d5ef9c98d7 --- /dev/null +++ b/.changeset/spicy-bags-sparkle.md @@ -0,0 +1,5 @@ +--- +'@firebase/app-check': patch +--- + +Clear App Check exchange promise correctly after request succeeds. diff --git a/packages/app-check/src/api.test.ts b/packages/app-check/src/api.test.ts index 724839226af..655e34e9d70 100644 --- a/packages/app-check/src/api.test.ts +++ b/packages/app-check/src/api.test.ts @@ -16,7 +16,7 @@ */ import '../test/setup'; import { expect } from 'chai'; -import { spy, stub } from 'sinon'; +import { SinonStub, spy, stub } from 'sinon'; import { setTokenAutoRefreshEnabled, initializeAppCheck, @@ -31,7 +31,12 @@ import { getFakeAppCheck, removegreCAPTCHAScriptsOnPage } from '../test/util'; -import { clearState, getState } from './state'; +import { + clearState, + DEFAULT_STATE, + getStateReference, + setInitialState +} from './state'; import * as reCAPTCHA from './recaptcha'; import * as util from './util'; import * as logger from './logger'; @@ -52,15 +57,21 @@ import { getDebugToken } from './debug'; describe('api', () => { let app: FirebaseApp; + let storageReadStub: SinonStub; + let storageWriteStub: SinonStub; beforeEach(() => { app = getFullApp(); + storageReadStub = stub(storage, 'readTokenFromStorage').resolves(undefined); + storageWriteStub = stub(storage, 'writeTokenToStorage'); stub(util, 'getRecaptcha').returns(getFakeGreCAPTCHA()); }); - afterEach(() => { + afterEach(async () => { clearState(); removegreCAPTCHAScriptsOnPage(); + storageReadStub.restore(); + storageWriteStub.restore(); return deleteApp(app); }); @@ -216,11 +227,11 @@ describe('api', () => { }); it('sets activated to true', () => { - expect(getState(app).activated).to.equal(false); + expect(getStateReference(app).activated).to.equal(false); initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - expect(getState(app).activated).to.equal(true); + expect(getStateReference(app).activated).to.equal(true); }); it('isTokenAutoRefreshEnabled value defaults to global setting', () => { @@ -228,7 +239,7 @@ describe('api', () => { initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - expect(getState(app).isTokenAutoRefreshEnabled).to.equal(false); + expect(getStateReference(app).isTokenAutoRefreshEnabled).to.equal(false); }); it('sets isTokenAutoRefreshEnabled correctly, overriding global setting', () => { @@ -237,15 +248,16 @@ describe('api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - expect(getState(app).isTokenAutoRefreshEnabled).to.equal(true); + expect(getStateReference(app).isTokenAutoRefreshEnabled).to.equal(true); }); }); describe('setTokenAutoRefreshEnabled()', () => { it('sets isTokenAutoRefreshEnabled correctly', () => { const app = getFakeApp({ automaticDataCollectionEnabled: false }); const appCheck = getFakeAppCheck(app); + setInitialState(app, { ...DEFAULT_STATE }); setTokenAutoRefreshEnabled(appCheck, true); - expect(getState(app).isTokenAutoRefreshEnabled).to.equal(true); + expect(getStateReference(app).isTokenAutoRefreshEnabled).to.equal(true); }); }); describe('getToken()', () => { @@ -279,7 +291,7 @@ describe('api', () => { isTokenAutoRefreshEnabled: true }); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); const fakeRecaptchaToken = 'fake-recaptcha-token'; const fakeRecaptchaAppCheckToken = { @@ -291,7 +303,6 @@ describe('api', () => { stub(client, 'exchangeToken').returns( Promise.resolve(fakeRecaptchaAppCheckToken) ); - stub(storage, 'writeTokenToStorage').returns(Promise.resolve(undefined)); const listener1 = stub().throws(new Error()); const listener2 = spy(); @@ -302,7 +313,7 @@ describe('api', () => { const unsubscribe1 = onTokenChanged(appCheck, listener1, errorFn1); const unsubscribe2 = onTokenChanged(appCheck, listener2, errorFn2); - expect(getState(app).tokenObservers.length).to.equal(3); + expect(getStateReference(app).tokenObservers.length).to.equal(3); await internalApi.getToken(appCheck as AppCheckService); @@ -315,7 +326,7 @@ describe('api', () => { expect(errorFn2).to.not.be.called; unsubscribe1(); unsubscribe2(); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); }); it('Listeners work when using Observer pattern', async () => { @@ -324,7 +335,7 @@ describe('api', () => { isTokenAutoRefreshEnabled: true }); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); const fakeRecaptchaToken = 'fake-recaptcha-token'; const fakeRecaptchaAppCheckToken = { @@ -336,7 +347,7 @@ describe('api', () => { stub(client, 'exchangeToken').returns( Promise.resolve(fakeRecaptchaAppCheckToken) ); - stub(storage, 'writeTokenToStorage').returns(Promise.resolve(undefined)); + storageWriteStub.returns(Promise.resolve(undefined)); const listener1 = stub().throws(new Error()); const listener2 = spy(); @@ -357,7 +368,7 @@ describe('api', () => { error: errorFn1 }); - expect(getState(app).tokenObservers.length).to.equal(3); + expect(getStateReference(app).tokenObservers.length).to.equal(3); await internalApi.getToken(appCheck as AppCheckService); @@ -370,7 +381,7 @@ describe('api', () => { expect(errorFn2).to.not.be.called; unsubscribe1(); unsubscribe2(); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); }); it('onError() catches token errors', async () => { @@ -380,12 +391,12 @@ describe('api', () => { isTokenAutoRefreshEnabled: false }); - expect(getState(app).tokenObservers.length).to.equal(0); + expect(getStateReference(app).tokenObservers.length).to.equal(0); const fakeRecaptchaToken = 'fake-recaptcha-token'; stub(reCAPTCHA, 'getToken').returns(Promise.resolve(fakeRecaptchaToken)); stub(client, 'exchangeToken').rejects('exchange error'); - stub(storage, 'writeTokenToStorage').returns(Promise.resolve(undefined)); + storageWriteStub.returns(Promise.resolve(undefined)); const listener1 = spy(); @@ -395,13 +406,13 @@ describe('api', () => { await internalApi.getToken(appCheck as AppCheckService); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); expect(errorFn1).to.be.calledOnce; expect(errorFn1.args[0][0].name).to.include('exchange error'); unsubscribe1(); - expect(getState(app).tokenObservers.length).to.equal(0); + expect(getStateReference(app).tokenObservers.length).to.equal(0); }); }); }); diff --git a/packages/app-check/src/api.ts b/packages/app-check/src/api.ts index b885cec1201..1b7e29e008d 100644 --- a/packages/app-check/src/api.ts +++ b/packages/app-check/src/api.ts @@ -23,7 +23,12 @@ import { PartialObserver } from './public-types'; import { ERROR_FACTORY, AppCheckError } from './errors'; -import { getState, setState, AppCheckState, getDebugState } from './state'; +import { + getStateReference, + getDebugState, + DEFAULT_STATE, + setInitialState +} from './state'; import { FirebaseApp, getApp, _getProvider } from '@firebase/app'; import { getModularInstance, ErrorFn, NextFn } from '@firebase/util'; import { AppCheckService } from './factory'; @@ -101,7 +106,7 @@ export function initializeAppCheck( // If isTokenAutoRefreshEnabled is false, do not send any requests to the // exchange endpoint without an explicit call from the user either directly // or through another Firebase library (storage, functions, etc.) - if (getState(app).isTokenAutoRefreshEnabled) { + if (getStateReference(app).isTokenAutoRefreshEnabled) { // Adding a listener will start the refresher and fetch a token if needed. // This gets a token ready and prevents a delay when an internal library // requests the token. @@ -128,13 +133,15 @@ function _activate( provider: AppCheckProvider, isTokenAutoRefreshEnabled?: boolean ): void { - const state = getState(app); + // Create an entry in the APP_CHECK_STATES map. Further changes should + // directly mutate this object. + const state = setInitialState(app, { ...DEFAULT_STATE }); - const newState: AppCheckState = { ...state, activated: true }; - newState.provider = provider; // Read cached token from storage if it exists and store it in memory. - newState.cachedTokenPromise = readTokenFromStorage(app).then(cachedToken => { + state.activated = true; + state.provider = provider; // Read cached token from storage if it exists and store it in memory. + state.cachedTokenPromise = readTokenFromStorage(app).then(cachedToken => { if (cachedToken && isValid(cachedToken)) { - setState(app, { ...getState(app), token: cachedToken }); + state.token = cachedToken; // notify all listeners with the cached token notifyTokenListeners(app, { token: cachedToken.token }); } @@ -144,14 +151,12 @@ function _activate( // Use value of global `automaticDataCollectionEnabled` (which // itself defaults to false if not specified in config) if // `isTokenAutoRefreshEnabled` param was not provided by user. - newState.isTokenAutoRefreshEnabled = + state.isTokenAutoRefreshEnabled = isTokenAutoRefreshEnabled === undefined ? app.automaticDataCollectionEnabled : isTokenAutoRefreshEnabled; - setState(app, newState); - - newState.provider.initialize(app); + state.provider.initialize(app); } /** @@ -168,7 +173,7 @@ export function setTokenAutoRefreshEnabled( isTokenAutoRefreshEnabled: boolean ): void { const app = appCheckInstance.app; - const state = getState(app); + const state = getStateReference(app); // This will exist if any product libraries have called // `addTokenListener()` if (state.tokenRefresher) { @@ -178,7 +183,7 @@ export function setTokenAutoRefreshEnabled( state.tokenRefresher.stop(); } } - setState(app, { ...state, isTokenAutoRefreshEnabled }); + state.isTokenAutoRefreshEnabled = isTokenAutoRefreshEnabled; } /** * Get the current App Check token. Attaches to the most recent diff --git a/packages/app-check/src/factory.ts b/packages/app-check/src/factory.ts index 1870772cdd1..438c527f154 100644 --- a/packages/app-check/src/factory.ts +++ b/packages/app-check/src/factory.ts @@ -24,7 +24,7 @@ import { removeTokenListener } from './internal-api'; import { Provider } from '@firebase/component'; -import { getState } from './state'; +import { getStateReference } from './state'; /** * AppCheck Service class. @@ -35,7 +35,7 @@ export class AppCheckService implements AppCheck, _FirebaseService { public heartbeatServiceProvider: Provider<'heartbeat'> ) {} _delete(): Promise { - const { tokenObservers } = getState(this.app); + const { tokenObservers } = getStateReference(this.app); for (const tokenObserver of tokenObservers) { removeTokenListener(this.app, tokenObserver.next); } diff --git a/packages/app-check/src/internal-api.test.ts b/packages/app-check/src/internal-api.test.ts index 03c690fb42e..4c29476e897 100644 --- a/packages/app-check/src/internal-api.test.ts +++ b/packages/app-check/src/internal-api.test.ts @@ -39,7 +39,12 @@ import * as client from './client'; import * as storage from './storage'; import * as util from './util'; import { logger } from './logger'; -import { getState, clearState, setState, getDebugState } from './state'; +import { + getStateReference, + clearState, + setInitialState, + getDebugState +} from './state'; import { AppCheckTokenListener } from './public-types'; import { Deferred } from '@firebase/util'; import { ReCaptchaEnterpriseProvider, ReCaptchaV3Provider } from './providers'; @@ -304,11 +309,11 @@ describe('internal api', () => { const clientStub = stub(client, 'exchangeToken'); - expect(getState(app).token).to.equal(undefined); + expect(getStateReference(app).token).to.equal(undefined); expect(await getToken(appCheck as AppCheckService)).to.deep.equal({ token: fakeCachedAppCheckToken.token }); - expect(getState(app).token).to.equal(fakeCachedAppCheckToken); + expect(getStateReference(app).token).to.equal(fakeCachedAppCheckToken); expect(clientStub).has.not.been.called; clock.restore(); @@ -337,7 +342,10 @@ describe('internal api', () => { const appCheck = initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - setState(app, { ...getState(app), token: fakeRecaptchaAppCheckToken }); + setInitialState(app, { + ...getStateReference(app), + token: fakeRecaptchaAppCheckToken + }); const clientStub = stub(client, 'exchangeToken'); expect(await getToken(appCheck as AppCheckService)).to.deep.equal({ @@ -352,7 +360,10 @@ describe('internal api', () => { const appCheck = initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - setState(app, { ...getState(app), token: fakeRecaptchaAppCheckToken }); + setInitialState(app, { + ...getStateReference(app), + token: fakeRecaptchaAppCheckToken + }); stub(reCAPTCHA, 'getToken').returns(Promise.resolve(fakeRecaptchaToken)); stub(client, 'exchangeToken').returns( @@ -368,12 +379,105 @@ describe('internal api', () => { }); }); + it('no dangling exchangeToken promise internal', async () => { + const appCheck = initializeAppCheck(app, { + provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) + }); + + setInitialState(app, { + ...getStateReference(app), + token: fakeRecaptchaAppCheckToken, + cachedTokenPromise: undefined + }); + + stub(reCAPTCHA, 'getToken').returns(Promise.resolve(fakeRecaptchaToken)); + stub(client, 'exchangeToken').returns( + Promise.resolve({ + token: 'new-recaptcha-app-check-token', + expireTimeMillis: Date.now() + 60000, + issuedAtTimeMillis: 0 + }) + ); + + const getTokenPromise = getToken(appCheck as AppCheckService, true); + + expect(getStateReference(app).exchangeTokenPromise).to.be.instanceOf( + Promise + ); + + await getTokenPromise; + + expect(getStateReference(app).exchangeTokenPromise).to.be.equal( + undefined + ); + }); + + it('no dangling exchangeToken promise', async () => { + const clock = useFakeTimers(); + + const appCheck = initializeAppCheck(app, { + provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) + }); + + const soonExpiredToken = { + token: `recaptcha-app-check-token-old`, + expireTimeMillis: Date.now() + 1000, + issuedAtTimeMillis: 0 + }; + + setInitialState(app, { + ...getStateReference(app), + token: soonExpiredToken, + cachedTokenPromise: undefined + }); + + stub(reCAPTCHA, 'getToken').returns(Promise.resolve(fakeRecaptchaToken)); + let count = 0; + stub(client, 'exchangeToken').callsFake( + () => + new Promise(res => + setTimeout( + () => + res({ + token: `recaptcha-app-check-token-new-${count++}`, + expireTimeMillis: Date.now() + 60000, + issuedAtTimeMillis: 0 + }), + 3000 + ) + ) + ); + + // start fetch token + void getToken(appCheck as AppCheckService, true); + + clock.tick(2000); + + // save expired `token-old` with copied state and wait fetch token + void getToken(appCheck as AppCheckService); + + // wait fetch token with copied state + void getToken(appCheck as AppCheckService); + + // stored copied state with `token-new-0` + await clock.runAllAsync(); + + // fetch token with copied state + const newToken = getToken(appCheck as AppCheckService, true); + + await clock.runAllAsync(); + + expect(await newToken).to.deep.equal({ + token: 'recaptcha-app-check-token-new-1' + }); + }); + it('ignores in-memory token if it is invalid and continues to exchange request', async () => { const appCheck = initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: 'something', expireTimeMillis: Date.now() - 1000, @@ -447,7 +551,10 @@ describe('internal api', () => { const appCheck = initializeAppCheck(app, { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY) }); - setState(app, { ...getState(app), token: fakeRecaptchaAppCheckToken }); + setInitialState(app, { + ...getStateReference(app), + token: fakeRecaptchaAppCheckToken + }); stub(reCAPTCHA, 'getToken').returns(Promise.resolve(fakeRecaptchaToken)); stub(client, 'exchangeToken').returns(Promise.reject(new Error('blah'))); @@ -531,10 +638,15 @@ describe('internal api', () => { }); describe('addTokenListener', () => { - it('adds token listeners', () => { + afterEach(async () => { + clearState(); + }); + + it('adds token listeners', async () => { + const clock = useFakeTimers(); const listener = (): void => {}; - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), cachedTokenPromise: Promise.resolve(undefined) }); @@ -544,18 +656,20 @@ describe('internal api', () => { listener ); - expect(getState(app).tokenObservers[0].next).to.equal(listener); + expect(getStateReference(app).tokenObservers[0].next).to.equal(listener); + // resolve initTokenRefresher dangling promise + await clock.tickAsync(1000); }); it('starts proactively refreshing token after adding the first listener', async () => { const listener = (): void => {}; - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), isTokenAutoRefreshEnabled: true, cachedTokenPromise: Promise.resolve(undefined) }); - expect(getState(app).tokenObservers.length).to.equal(0); - expect(getState(app).tokenRefresher).to.equal(undefined); + expect(getStateReference(app).tokenObservers.length).to.equal(0); + expect(getStateReference(app).tokenRefresher).to.equal(undefined); addTokenListener( { app } as AppCheckService, @@ -563,13 +677,14 @@ describe('internal api', () => { listener ); - expect(getState(app).tokenRefresher?.isRunning()).to.be.undefined; + expect(getStateReference(app).tokenRefresher?.isRunning()).to.be + .undefined; // addTokenListener() waits for the result of cachedTokenPromise // before starting the refresher - await getState(app).cachedTokenPromise; + await getStateReference(app).cachedTokenPromise; - expect(getState(app).tokenRefresher?.isRunning()).to.be.true; + expect(getStateReference(app).tokenRefresher?.isRunning()).to.be.true; }); it('notifies the listener with the valid token in memory immediately', async () => { @@ -577,8 +692,8 @@ describe('internal api', () => { const listener = stub(); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), cachedTokenPromise: Promise.resolve(undefined), token: { token: `fake-memory-app-check-token`, @@ -630,8 +745,8 @@ describe('internal api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: `fake-cached-app-check-token`, // within refresh window @@ -674,8 +789,8 @@ describe('internal api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: `fake-cached-app-check-token`, // not expired but within refresh window @@ -713,8 +828,8 @@ describe('internal api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: `fake-cached-app-check-token`, // not expired but within refresh window @@ -750,8 +865,8 @@ describe('internal api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: `fake-cached-app-check-token`, // expired @@ -790,8 +905,8 @@ describe('internal api', () => { provider: new ReCaptchaV3Provider(FAKE_SITE_KEY), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), token: { token: `fake-cached-app-check-token`, // expired @@ -827,8 +942,8 @@ describe('internal api', () => { describe('removeTokenListener', () => { it('should remove token listeners', () => { const listener = (): void => {}; - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), cachedTokenPromise: Promise.resolve(undefined) }); addTokenListener( @@ -836,17 +951,20 @@ describe('internal api', () => { ListenerType.INTERNAL, listener ); - expect(getState(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenObservers.length).to.equal(1); removeTokenListener(app, listener); - expect(getState(app).tokenObservers.length).to.equal(0); + expect(getStateReference(app).tokenObservers.length).to.equal(0); }); it('should stop proactively refreshing token after deleting the last listener', async () => { const listener = (): void => {}; - setState(app, { ...getState(app), isTokenAutoRefreshEnabled: true }); - setState(app, { - ...getState(app), + setInitialState(app, { + ...getStateReference(app), + isTokenAutoRefreshEnabled: true + }); + setInitialState(app, { + ...getStateReference(app), cachedTokenPromise: Promise.resolve(undefined) }); @@ -858,14 +976,14 @@ describe('internal api', () => { // addTokenListener() waits for the result of cachedTokenPromise // before starting the refresher - await getState(app).cachedTokenPromise; + await getStateReference(app).cachedTokenPromise; - expect(getState(app).tokenObservers.length).to.equal(1); - expect(getState(app).tokenRefresher?.isRunning()).to.be.true; + expect(getStateReference(app).tokenObservers.length).to.equal(1); + expect(getStateReference(app).tokenRefresher?.isRunning()).to.be.true; removeTokenListener(app, listener); - expect(getState(app).tokenObservers.length).to.equal(0); - expect(getState(app).tokenRefresher?.isRunning()).to.be.false; + expect(getStateReference(app).tokenObservers.length).to.equal(0); + expect(getStateReference(app).tokenRefresher?.isRunning()).to.be.false; }); }); }); diff --git a/packages/app-check/src/internal-api.ts b/packages/app-check/src/internal-api.ts index 99c0f5d2c00..f07b044be8a 100644 --- a/packages/app-check/src/internal-api.ts +++ b/packages/app-check/src/internal-api.ts @@ -23,7 +23,7 @@ import { ListenerType } from './types'; import { AppCheckTokenListener } from './public-types'; -import { getState, setState } from './state'; +import { getStateReference } from './state'; import { TOKEN_REFRESH_TIME } from './constants'; import { Refresher } from './proactive-refresh'; import { ensureActivated } from './util'; @@ -65,7 +65,7 @@ export async function getToken( const app = appCheck.app; ensureActivated(app); - const state = getState(app); + const state = getStateReference(app); /** * First check if there is a token in memory from a previous `getToken()` call. @@ -78,7 +78,7 @@ export async function getToken( * memory and unset the local variable `token`. */ if (token && !isValid(token)) { - setState(app, { ...state, token: undefined }); + state.token = undefined; token = undefined; } @@ -132,7 +132,7 @@ export async function getToken( // Write debug token to indexedDB. await writeTokenToStorage(app, tokenFromDebugExchange); // Write debug token to state. - setState(app, { ...state, token: tokenFromDebugExchange }); + state.token = tokenFromDebugExchange; return { token: tokenFromDebugExchange.token }; } @@ -153,7 +153,7 @@ export async function getToken( }); shouldCallListeners = true; } - token = await state.exchangeTokenPromise; + token = await getStateReference(app).exchangeTokenPromise; } catch (e) { if ((e as FirebaseError).code === `appCheck/${AppCheckError.THROTTLED}`) { // Warn if throttled, but do not treat it as an error. @@ -195,7 +195,7 @@ export async function getToken( }; // write the new token to the memory state as well as the persistent storage. // Only do it if we got a valid new token - setState(app, { ...state, token }); + state.token = token; await writeTokenToStorage(app, token); } @@ -212,16 +212,13 @@ export function addTokenListener( onError?: (error: Error) => void ): void { const { app } = appCheck; - const state = getState(app); + const state = getStateReference(app); const tokenObserver: AppCheckTokenObserver = { next: listener, error: onError, type }; - setState(app, { - ...state, - tokenObservers: [...state.tokenObservers, tokenObserver] - }); + state.tokenObservers = [...state.tokenObservers, tokenObserver]; // Invoke the listener async immediately if there is a valid token // in memory. @@ -255,7 +252,7 @@ export function removeTokenListener( app: FirebaseApp, listener: AppCheckTokenListener ): void { - const state = getState(app); + const state = getStateReference(app); const newObservers = state.tokenObservers.filter( tokenObserver => tokenObserver.next !== listener @@ -268,10 +265,7 @@ export function removeTokenListener( state.tokenRefresher.stop(); } - setState(app, { - ...state, - tokenObservers: newObservers - }); + state.tokenObservers = newObservers; } /** @@ -279,13 +273,13 @@ export function removeTokenListener( */ function initTokenRefresher(appCheck: AppCheckService): void { const { app } = appCheck; - const state = getState(app); + const state = getStateReference(app); // Create the refresher but don't start it if `isTokenAutoRefreshEnabled` // is not true. let refresher: Refresher | undefined = state.tokenRefresher; if (!refresher) { refresher = createTokenRefresher(appCheck); - setState(app, { ...state, tokenRefresher: refresher }); + state.tokenRefresher = refresher; } if (!refresher.isRunning() && state.isTokenAutoRefreshEnabled) { refresher.start(); @@ -298,7 +292,7 @@ function createTokenRefresher(appCheck: AppCheckService): Refresher { // Keep in mind when this fails for any reason other than the ones // for which we should retry, it will effectively stop the proactive refresh. async () => { - const state = getState(app); + const state = getStateReference(app); // If there is no token, we will try to load it from storage and use it // If there is a token, we force refresh it because we know it's going to expire soon let result; @@ -331,7 +325,7 @@ function createTokenRefresher(appCheck: AppCheckService): Refresher { return true; }, () => { - const state = getState(app); + const state = getStateReference(app); if (state.token) { // issuedAtTime + (50% * total TTL) + 5 minutes @@ -361,7 +355,7 @@ export function notifyTokenListeners( app: FirebaseApp, token: AppCheckTokenResult ): void { - const observers = getState(app).tokenObservers; + const observers = getStateReference(app).tokenObservers; for (const observer of observers) { try { diff --git a/packages/app-check/src/recaptcha.test.ts b/packages/app-check/src/recaptcha.test.ts index 2716064c6c0..8d371025195 100644 --- a/packages/app-check/src/recaptcha.test.ts +++ b/packages/app-check/src/recaptcha.test.ts @@ -33,7 +33,12 @@ import { GreCAPTCHATopLevel } from './recaptcha'; import * as utils from './util'; -import { getState } from './state'; +import { + clearState, + DEFAULT_STATE, + getStateReference, + setInitialState +} from './state'; import { Deferred } from '@firebase/util'; import { initializeAppCheck } from './api'; import { ReCaptchaEnterpriseProvider, ReCaptchaV3Provider } from './providers'; @@ -43,9 +48,11 @@ describe('recaptcha', () => { beforeEach(() => { app = getFullApp(); + setInitialState(app, { ...DEFAULT_STATE }); }); afterEach(() => { + clearState(); removegreCAPTCHAScriptsOnPage(); return deleteApp(app); }); @@ -53,11 +60,11 @@ describe('recaptcha', () => { describe('initialize() - V3', () => { it('sets reCAPTCHAState', async () => { self.grecaptcha = getFakeGreCAPTCHA() as GreCAPTCHATopLevel; - expect(getState(app).reCAPTCHAState).to.equal(undefined); + expect(getStateReference(app).reCAPTCHAState).to.equal(undefined); await initializeV3(app, FAKE_SITE_KEY); - expect(getState(app).reCAPTCHAState?.initialized).to.be.instanceof( - Deferred - ); + expect( + getStateReference(app).reCAPTCHAState?.initialized + ).to.be.instanceof(Deferred); }); it('loads reCAPTCHA script if it was not loaded already', async () => { @@ -89,18 +96,20 @@ describe('recaptcha', () => { size: 'invisible' }); - expect(getState(app).reCAPTCHAState?.widgetId).to.equal('fake_widget_1'); + expect(getStateReference(app).reCAPTCHAState?.widgetId).to.equal( + 'fake_widget_1' + ); }); }); describe('initialize() - Enterprise', () => { it('sets reCAPTCHAState', async () => { self.grecaptcha = getFakeGreCAPTCHA() as GreCAPTCHATopLevel; - expect(getState(app).reCAPTCHAState).to.equal(undefined); + expect(getStateReference(app).reCAPTCHAState).to.equal(undefined); await initializeEnterprise(app, FAKE_SITE_KEY); - expect(getState(app).reCAPTCHAState?.initialized).to.be.instanceof( - Deferred - ); + expect( + getStateReference(app).reCAPTCHAState?.initialized + ).to.be.instanceof(Deferred); }); it('loads reCAPTCHA script if it was not loaded already', async () => { @@ -135,7 +144,9 @@ describe('recaptcha', () => { size: 'invisible' }); - expect(getState(app).reCAPTCHAState?.widgetId).to.equal('fake_widget_1'); + expect(getStateReference(app).reCAPTCHAState?.widgetId).to.equal( + 'fake_widget_1' + ); }); }); diff --git a/packages/app-check/src/recaptcha.ts b/packages/app-check/src/recaptcha.ts index bf85b67dd40..2956f830a7b 100644 --- a/packages/app-check/src/recaptcha.ts +++ b/packages/app-check/src/recaptcha.ts @@ -16,7 +16,7 @@ */ import { FirebaseApp } from '@firebase/app'; -import { getState, setState } from './state'; +import { getStateReference } from './state'; import { Deferred } from '@firebase/util'; import { getRecaptcha, ensureActivated } from './util'; @@ -28,10 +28,11 @@ export function initializeV3( app: FirebaseApp, siteKey: string ): Promise { - const state = getState(app); const initialized = new Deferred(); - setState(app, { ...state, reCAPTCHAState: { initialized } }); + const state = getStateReference(app); + state.reCAPTCHAState = { initialized }; + const divId = makeDiv(app); const grecaptcha = getRecaptcha(false); @@ -54,10 +55,11 @@ export function initializeEnterprise( app: FirebaseApp, siteKey: string ): Promise { - const state = getState(app); const initialized = new Deferred(); - setState(app, { ...state, reCAPTCHAState: { initialized } }); + const state = getStateReference(app); + state.reCAPTCHAState = { initialized }; + const divId = makeDiv(app); const grecaptcha = getRecaptcha(true); @@ -113,12 +115,12 @@ export async function getToken(app: FirebaseApp): Promise { ensureActivated(app); // ensureActivated() guarantees that reCAPTCHAState is set - const reCAPTCHAState = getState(app).reCAPTCHAState!; + const reCAPTCHAState = getStateReference(app).reCAPTCHAState!; const recaptcha = await reCAPTCHAState.initialized.promise; return new Promise((resolve, _reject) => { // Updated after initialization is complete. - const reCAPTCHAState = getState(app).reCAPTCHAState!; + const reCAPTCHAState = getStateReference(app).reCAPTCHAState!; recaptcha.ready(() => { resolve( // widgetId is guaranteed to be available if reCAPTCHAState.initialized.promise resolved. @@ -145,16 +147,12 @@ function renderInvisibleWidget( sitekey: siteKey, size: 'invisible' }); + const state = getStateReference(app); - const state = getState(app); - - setState(app, { - ...state, - reCAPTCHAState: { - ...state.reCAPTCHAState!, // state.reCAPTCHAState is set in the initialize() - widgetId - } - }); + state.reCAPTCHAState = { + ...state.reCAPTCHAState!, // state.reCAPTCHAState is set in the initialize() + widgetId + }; } function loadReCAPTCHAV3Script(onload: () => void): void { diff --git a/packages/app-check/src/state.ts b/packages/app-check/src/state.ts index 0a81ba81636..23c0efc0bd9 100644 --- a/packages/app-check/src/state.ts +++ b/packages/app-check/src/state.ts @@ -58,12 +58,23 @@ const DEBUG_STATE: DebugState = { enabled: false }; -export function getState(app: FirebaseApp): AppCheckState { - return APP_CHECK_STATES.get(app) || DEFAULT_STATE; +/** + * Gets a reference to the state object. + */ +export function getStateReference(app: FirebaseApp): AppCheckState { + return APP_CHECK_STATES.get(app) || { ...DEFAULT_STATE }; } -export function setState(app: FirebaseApp, state: AppCheckState): void { +/** + * Set once on initialization. The map should hold the same reference to the + * same object until this entry is deleted. + */ +export function setInitialState( + app: FirebaseApp, + state: AppCheckState +): AppCheckState { APP_CHECK_STATES.set(app, state); + return APP_CHECK_STATES.get(app) as AppCheckState; } // for testing only diff --git a/packages/app-check/src/util.ts b/packages/app-check/src/util.ts index de9cc91f623..0ccb502d79f 100644 --- a/packages/app-check/src/util.ts +++ b/packages/app-check/src/util.ts @@ -16,7 +16,7 @@ */ import { GreCAPTCHA } from './recaptcha'; -import { getState } from './state'; +import { getStateReference } from './state'; import { ERROR_FACTORY, AppCheckError } from './errors'; import { FirebaseApp } from '@firebase/app'; @@ -30,7 +30,7 @@ export function getRecaptcha( } export function ensureActivated(app: FirebaseApp): void { - if (!getState(app).activated) { + if (!getStateReference(app).activated) { throw ERROR_FACTORY.create(AppCheckError.USE_BEFORE_ACTIVATION, { appName: app.name });