Skip to content

feat: improve caching by only decoding jwks when necessary #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/CachedKeySet.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
namespace Firebase\JWT;

use ArrayAccess;
use InvalidArgumentException;
use LogicException;
use OutOfBoundsException;
use Psr\Cache\CacheItemInterface;
use Psr\Cache\CacheItemPoolInterface;
use Psr\Http\Client\ClientInterface;
use Psr\Http\Message\RequestFactoryInterface;
use RuntimeException;
use UnexpectedValueException;

/**
* @implements ArrayAccess<string, Key>
Expand Down Expand Up @@ -41,7 +43,7 @@ class CachedKeySet implements ArrayAccess
*/
private $cacheItem;
/**
* @var array<string, Key>
* @var array<string, array<mixed>>
*/
private $keySet;
/**
Expand Down Expand Up @@ -101,7 +103,7 @@ public function offsetGet($keyId): Key
if (!$this->keyIdExists($keyId)) {
throw new OutOfBoundsException('Key ID not found');
}
return $this->keySet[$keyId];
return JWK::parseKey($this->keySet[$keyId], $this->defaultAlg);
}

/**
Expand Down Expand Up @@ -130,15 +132,43 @@ public function offsetUnset($offset): void
throw new LogicException('Method not implemented');
}

/**
* @return array<mixed>
*/
private function formatJwksForCache(string $jwks): array
{
$jwks = json_decode($jwks, true);

if (!isset($jwks['keys'])) {
throw new UnexpectedValueException('"keys" member must exist in the JWK Set');
}

if (empty($jwks['keys'])) {
throw new InvalidArgumentException('JWK Set did not contain any keys');
}

$keys = [];
foreach ($jwks['keys'] as $k => $v) {
$kid = isset($v['kid']) ? $v['kid'] : $k;
$keys[(string) $kid] = $v;
}

return $keys;
}

private function keyIdExists(string $keyId): bool
{
if (null === $this->keySet) {
$item = $this->getCacheItem();
// Try to load keys from cache
if ($item->isHit()) {
// item found! Return it
$jwks = $item->get();
$this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg);
// item found! retrieve it
$this->keySet = $item->get();
// If the cached item is a string, the JWKS response was cached (previous behavior).
// Parse this into expected format array<kid, jwk> instead.
if (\is_string($this->keySet)) {
$this->keySet = $this->formatJwksForCache($this->keySet);
}
}
}

Expand All @@ -148,15 +178,14 @@ private function keyIdExists(string $keyId): bool
}
$request = $this->httpFactory->createRequest('GET', $this->jwksUri);
$jwksResponse = $this->httpClient->sendRequest($request);
$jwks = (string) $jwksResponse->getBody();
$this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg);
$this->keySet = $this->formatJwksForCache((string) $jwksResponse->getBody());

if (!isset($this->keySet[$keyId])) {
return false;
}

$item = $this->getCacheItem();
$item->set($jwks);
$item->set($this->keySet);
if ($this->expiresAfter) {
$item->expiresAfter($this->expiresAfter);
}
Expand Down
77 changes: 72 additions & 5 deletions tests/CachedKeySetTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ class CachedKeySetTest extends TestCase
private $testJwksUri = 'https://jwk.uri';
private $testJwksUriKey = 'jwkshttpsjwk.uri';
private $testJwks1 = '{"keys": [{"kid":"foo","kty":"RSA","alg":"foo","n":"","e":""}]}';
private $testCachedJwks1 = ['foo' => ['kid' => 'foo', 'kty' => 'RSA', 'alg' => 'foo', 'n' => '', 'e' => '']];
private $testJwks2 = '{"keys": [{"kid":"bar","kty":"RSA","alg":"bar","n":"","e":""}]}';
private $testJwks3 = '{"keys": [{"kid":"baz","kty":"RSA","n":"","e":""}]}';

private $googleRsaUri = 'https://www.googleapis.com/oauth2/v3/certs';
// private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk';
private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk';

public function testEmptyUriThrowsException()
{
Expand Down Expand Up @@ -117,7 +118,7 @@ public function testKeyIdIsCached()
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn($this->testJwks1);
->willReturn($this->testCachedJwks1);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
Expand All @@ -136,6 +137,66 @@ public function testKeyIdIsCached()
}

public function testCachedKeyIdRefresh()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->shouldBeCalledOnce()
->willReturn(true);
$cacheItem->get()
->shouldBeCalledOnce()
->willReturn($this->testCachedJwks1);
$cacheItem->set(Argument::any())
->shouldBeCalledOnce()
->will(function () {
return $this;
});

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
->shouldBeCalledOnce()
->willReturn($cacheItem->reveal());
$cache->save(Argument::any())
->shouldBeCalledOnce()
->willReturn(true);

$cachedKeySet = new CachedKeySet(
$this->testJwksUri,
$this->getMockHttpClient($this->testJwks2), // updated JWK
$this->getMockHttpFactory(),
$cache->reveal()
);
$this->assertInstanceOf(Key::class, $cachedKeySet['foo']);
$this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm());

$this->assertInstanceOf(Key::class, $cachedKeySet['bar']);
$this->assertSame('bar', $cachedKeySet['bar']->getAlgorithm());
}

public function testKeyIdIsCachedFromPreviousFormat()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn($this->testJwks1);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
->willReturn($cacheItem->reveal());
$cache->save(Argument::any())
->willReturn(true);

$cachedKeySet = new CachedKeySet(
$this->testJwksUri,
$this->prophesize(ClientInterface::class)->reveal(),
$this->prophesize(RequestFactoryInterface::class)->reveal(),
$cache->reveal()
);
$this->assertInstanceOf(Key::class, $cachedKeySet['foo']);
$this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm());
}

public function testCachedKeyIdRefreshFromPreviousFormat()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
Expand Down Expand Up @@ -213,12 +274,18 @@ public function testJwtVerify()
$payload = ['sub' => 'foo', 'exp' => strtotime('+10 seconds')];
$msg = JWT::encode($payload, $privKey1, 'RS256', 'jwk1');

// format the cached value to match the expected format
$cachedJwks = [];
$rsaKeySet = file_get_contents(__DIR__ . '/data/rsa-jwkset.json');
foreach (json_decode($rsaKeySet, true)['keys'] as $k => $v) {
$cachedJwks[$v['kid']] = $v;
}

$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn(file_get_contents(__DIR__ . '/data/rsa-jwkset.json')
);
->willReturn($cachedJwks);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
Expand Down Expand Up @@ -297,7 +364,7 @@ public function provideFullIntegration()
{
return [
[$this->googleRsaUri],
// [$this->googleEcUri, 'LYyP2g']
[$this->googleEcUri, 'LYyP2g']
];
}

Expand Down