Skip to content

Commit b711380

Browse files
authored
Merge pull request kubernetes#128670 from liggitt/externaljwt-broadcast
Move broadcast of key updates into sync, fixup of externaljwt generation / test
2 parents ab30adc + 070f74b commit b711380

File tree

6 files changed

+328
-48
lines changed

6 files changed

+328
-48
lines changed

pkg/serviceaccount/externaljwt/plugin/keycache.go

+38-22
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ import (
2525
"time"
2626

2727
"golang.org/x/sync/singleflight"
28-
"k8s.io/klog/v2"
29-
"k8s.io/kubernetes/pkg/serviceaccount"
3028

3129
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
30+
"k8s.io/klog/v2"
31+
"k8s.io/kubernetes/pkg/serviceaccount"
3232
externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics"
3333
)
3434

@@ -56,7 +56,7 @@ func newKeyCache(client externaljwtv1alpha1.ExternalJWTSignerClient) *keyCache {
5656
// InitialFill can be used to perform an initial fetch for keys get the
5757
// refresh interval as recommended by external signer.
5858
func (p *keyCache) initialFill(ctx context.Context) error {
59-
if _, err := p.syncKeys(ctx); err != nil {
59+
if err := p.syncKeys(ctx); err != nil {
6060
return fmt.Errorf("while performing initial cache fill: %w", err)
6161
}
6262
return nil
@@ -66,7 +66,6 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
6666
timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
6767
defer timer.Stop()
6868

69-
var lastDataTimestamp time.Time
7069
for {
7170
select {
7271
case <-ctx.Done():
@@ -76,16 +75,11 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
7675
}
7776

7877
timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout)
79-
dataTimestamp, err := p.syncKeys(timedCtx)
80-
if err != nil {
78+
if err := p.syncKeys(timedCtx); err != nil {
8179
klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err)
8280
timer.Reset(fallbackRefreshDuration)
8381
} else {
8482
timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
85-
if lastDataTimestamp.IsZero() || !dataTimestamp.Equal(lastDataTimestamp) {
86-
lastDataTimestamp = dataTimestamp
87-
p.broadcastUpdate()
88-
}
8983
}
9084
cancel()
9185
}
@@ -115,7 +109,7 @@ func (p *keyCache) GetPublicKeys(ctx context.Context, keyID string) []serviceacc
115109
}
116110

117111
// If we didn't find it, trigger a sync.
118-
if _, err := p.syncKeys(ctx); err != nil {
112+
if err := p.syncKeys(ctx); err != nil {
119113
klog.ErrorS(err, "Error while syncing keys")
120114
return []serviceaccount.PublicKey{}
121115
}
@@ -152,35 +146,57 @@ func (p *keyCache) findKeyForKeyID(keyID string) ([]serviceaccount.PublicKey, bo
152146

153147
// sync supported external keys.
154148
// completely re-writes the set of supported keys.
155-
func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
156-
val, err, _ := p.syncGroup.Do("", func() (any, error) {
149+
func (p *keyCache) syncKeys(ctx context.Context) error {
150+
_, err, _ := p.syncGroup.Do("", func() (any, error) {
151+
oldPublicKeys := p.verificationKeys.Load()
157152
newPublicKeys, err := p.getTokenVerificationKeys(ctx)
158153
externaljwtmetrics.RecordFetchKeysAttempt(err)
159154
if err != nil {
160155
return nil, fmt.Errorf("while fetching token verification keys: %w", err)
161156
}
162157

163158
p.verificationKeys.Store(newPublicKeys)
164-
165159
externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix())
166160

167-
return newPublicKeys, nil
168-
})
169-
if err != nil {
170-
return time.Time{}, err
171-
}
161+
if keysChanged(oldPublicKeys, newPublicKeys) {
162+
p.broadcastUpdate()
163+
}
172164

173-
vk := val.(*VerificationKeys)
165+
return nil, nil
166+
})
167+
return err
168+
}
174169

175-
return vk.DataTimestamp, nil
170+
// keysChanged returns true if the data timestamp, key count, order of key ids or excludeFromOIDCDiscovery indicators
171+
func keysChanged(oldPublicKeys, newPublicKeys *VerificationKeys) bool {
172+
// If the timestamp changed, we changed
173+
if !oldPublicKeys.DataTimestamp.Equal(newPublicKeys.DataTimestamp) {
174+
return true
175+
}
176+
// Avoid deepequal checks on key content itself.
177+
// If the number of keys changed, we changed
178+
if len(oldPublicKeys.Keys) != len(newPublicKeys.Keys) {
179+
return true
180+
}
181+
// If the order, key id, or oidc discovery flag changed, we changed.
182+
for i := range oldPublicKeys.Keys {
183+
if oldPublicKeys.Keys[i].KeyID != newPublicKeys.Keys[i].KeyID {
184+
return true
185+
}
186+
if oldPublicKeys.Keys[i].ExcludeFromOIDCDiscovery != newPublicKeys.Keys[i].ExcludeFromOIDCDiscovery {
187+
return true
188+
}
189+
}
190+
return false
176191
}
177192

178193
func (p *keyCache) broadcastUpdate() {
179194
p.listenersLock.Lock()
180195
defer p.listenersLock.Unlock()
181196

182197
for _, l := range p.listeners {
183-
l.Enqueue()
198+
// don't block on a slow listener
199+
go l.Enqueue()
184200
}
185201
}
186202

pkg/serviceaccount/externaljwt/plugin/keycache_test.go

+110-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"net"
2323
"os"
2424
"strings"
25+
"sync/atomic"
2526
"testing"
2627
"time"
2728

@@ -31,6 +32,7 @@ import (
3132
"google.golang.org/grpc/credentials/insecure"
3233
"google.golang.org/protobuf/types/known/timestamppb"
3334

35+
"k8s.io/apimachinery/pkg/util/wait"
3436
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
3537
"k8s.io/kubernetes/pkg/serviceaccount"
3638
)
@@ -167,7 +169,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
167169
t.Run(tc.desc, func(t *testing.T) {
168170
ctx := context.Background()
169171

170-
sockname := fmt.Sprintf("@test-external-public-key-getter-%d.sock", i)
172+
sockname := fmt.Sprintf("@test-external-public-key-getter-%d-%d.sock", time.Now().Nanosecond(), i)
171173
t.Cleanup(func() { _ = os.Remove(sockname) })
172174

173175
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@@ -238,7 +240,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
238240
func TestInitialFill(t *testing.T) {
239241
ctx := context.Background()
240242

241-
sockname := "@test-initial-fill.sock"
243+
sockname := fmt.Sprintf("@test-initial-fill-%d.sock", time.Now().Nanosecond())
242244
t.Cleanup(func() { _ = os.Remove(sockname) })
243245

244246
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@@ -304,7 +306,7 @@ func TestInitialFill(t *testing.T) {
304306
func TestReflectChanges(t *testing.T) {
305307
ctx := context.Background()
306308

307-
sockname := "@test-reflect-changes.sock"
309+
sockname := fmt.Sprintf("@test-reflect-changes-%d.sock", time.Now().Nanosecond())
308310
t.Cleanup(func() { _ = os.Remove(sockname) })
309311

310312
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@@ -357,18 +359,25 @@ func TestReflectChanges(t *testing.T) {
357359

358360
plugin := newPlugin("iss", clientConn, true)
359361

362+
dummyListener := &dummyListener{}
363+
plugin.keyCache.AddListener(dummyListener)
364+
365+
dummyListener.waitForCount(t, 0)
360366
if err := plugin.keyCache.initialFill(ctx); err != nil {
361367
t.Fatalf("Error during InitialFill: %v", err)
362368
}
369+
dummyListener.waitForCount(t, 1)
363370

364371
gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "")
365372
if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
366373
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
367374
}
368375

369-
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
376+
dummyListener.waitForCount(t, 1)
377+
if err := plugin.keyCache.syncKeys(ctx); err != nil {
370378
t.Fatalf("Error while calling syncKeys: %v", err)
371379
}
380+
dummyListener.waitForCount(t, 1)
372381

373382
supportedKeysT2 := map[string]supportedKeyT{
374383
"key-1": {
@@ -396,12 +405,108 @@ func TestReflectChanges(t *testing.T) {
396405
backend.supportedKeys = supportedKeysT2
397406
backend.keyLock.Unlock()
398407

399-
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
408+
dummyListener.waitForCount(t, 1)
409+
if err := plugin.keyCache.syncKeys(ctx); err != nil {
400410
t.Fatalf("Error while calling syncKeys: %v", err)
401411
}
412+
dummyListener.waitForCount(t, 2)
402413

403414
gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "")
404415
if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
405416
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
406417
}
418+
dummyListener.waitForCount(t, 2)
419+
}
420+
421+
type dummyListener struct {
422+
count atomic.Int64
423+
}
424+
425+
func (d *dummyListener) waitForCount(t *testing.T, expect int) {
426+
t.Helper()
427+
err := wait.PollUntilContextTimeout(context.Background(), time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
428+
actual := int(d.count.Load())
429+
switch {
430+
case actual > expect:
431+
return false, fmt.Errorf("expected %d broadcasts, got %d broadcasts", expect, actual)
432+
case actual == expect:
433+
return true, nil
434+
default:
435+
t.Logf("expected %d broadcasts, got %d broadcasts, waiting...", expect, actual)
436+
return false, nil
437+
}
438+
})
439+
if err != nil {
440+
t.Fatal(err)
441+
}
442+
}
443+
444+
func (d *dummyListener) Enqueue() {
445+
d.count.Add(1)
446+
}
447+
448+
func TestKeysChanged(t *testing.T) {
449+
testcases := []struct {
450+
name string
451+
oldKeys VerificationKeys
452+
newKeys VerificationKeys
453+
expect bool
454+
}{
455+
{
456+
name: "empty",
457+
oldKeys: VerificationKeys{},
458+
newKeys: VerificationKeys{},
459+
expect: false,
460+
},
461+
{
462+
name: "identical",
463+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
464+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
465+
expect: false,
466+
},
467+
{
468+
name: "changed datatimestamp",
469+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
470+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1001, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
471+
expect: true,
472+
},
473+
{
474+
name: "reordered keyid",
475+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
476+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}, {KeyID: "a"}}},
477+
expect: true,
478+
},
479+
{
480+
name: "changed keyid",
481+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
482+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}}},
483+
expect: true,
484+
},
485+
{
486+
name: "added key",
487+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
488+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
489+
expect: true,
490+
},
491+
{
492+
name: "removed key",
493+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
494+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
495+
expect: true,
496+
},
497+
{
498+
name: "changed oidc",
499+
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: false}}},
500+
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: true}}},
501+
expect: true,
502+
},
503+
}
504+
for _, tc := range testcases {
505+
t.Run(tc.name, func(t *testing.T) {
506+
result := keysChanged(&tc.oldKeys, &tc.newKeys)
507+
if result != tc.expect {
508+
t.Errorf("got %v, expected %v", result, tc.expect)
509+
}
510+
})
511+
}
407512
}

pkg/serviceaccount/externaljwt/plugin/plugin_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ func TestExternalTokenGenerator(t *testing.T) {
258258
t.Run(tc.desc, func(t *testing.T) {
259259
ctx := context.Background()
260260

261-
sockname := fmt.Sprintf("@test-external-token-generator-%d.sock", i)
261+
sockname := fmt.Sprintf("@test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i)
262262
t.Cleanup(func() { _ = os.Remove(sockname) })
263263

264264
addr := &net.UnixAddr{Name: sockname, Net: "unix"}

0 commit comments

Comments
 (0)