Skip to content

Commit 395beae

Browse files
committed
Recompute active realms when license changes
This commit changes the implementation of the Realms class to listen for license changes, and recompute the set of actively licensed realms only when the license changes rather than each time the "asList" method is called. This is primarily a performance optimisation, but it also allows us to turn off the "in use" license tracking for realms when they are disabled by a change in license. Relates: elastic#76476
1 parent 58f66cf commit 395beae

File tree

4 files changed

+171
-70
lines changed

4 files changed

+171
-70
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/license/MockLicenseState.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,9 @@ public boolean isAllowed(LicensedFeature feature) {
2525
public void enableUsageTracking(LicensedFeature feature, String contextName) {
2626
super.enableUsageTracking(feature, contextName);
2727
}
28+
29+
@Override
30+
public void disableUsageTracking(LicensedFeature feature, String contextName) {
31+
super.disableUsageTracking(feature, contextName);
32+
}
2833
}

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/Realms.java

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import org.apache.logging.log4j.LogManager;
1010
import org.apache.logging.log4j.Logger;
11-
import org.elasticsearch.Assertions;
1211
import org.elasticsearch.action.ActionListener;
1312
import org.elasticsearch.common.Strings;
1413
import org.elasticsearch.common.collect.MapBuilder;
@@ -30,7 +29,6 @@
3029
import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
3130

3231
import java.util.ArrayList;
33-
import java.util.Arrays;
3432
import java.util.Collections;
3533
import java.util.HashMap;
3634
import java.util.HashSet;
@@ -59,12 +57,11 @@ public class Realms implements Iterable<Realm> {
5957
private final ThreadContext threadContext;
6058
private final ReservedRealm reservedRealm;
6159

62-
protected List<Realm> realms;
63-
// a list of realms that are considered standard in that they are provided by x-pack and
64-
// interact with a 3rd party source on a limited basis
65-
List<Realm> standardRealmsOnly;
66-
// a list of realms that are considered native, that is they only interact with x-pack and no 3rd party auth sources
67-
List<Realm> nativeRealmsOnly;
60+
// All realms that were explicitly configured in the settings, some of these may not be enabled due to licensing
61+
private final List<Realm> allConfiguredRealms;
62+
63+
// the realms in current use. This list will change dynamically as the license changes
64+
private volatile List<Realm> activeRealms;
6865

6966
public Realms(Settings settings, Environment env, Map<String, Realm.Factory> factories, XPackLicenseState licenseState,
7067
ThreadContext threadContext, ReservedRealm reservedRealm) throws Exception {
@@ -74,35 +71,40 @@ public Realms(Settings settings, Environment env, Map<String, Realm.Factory> fac
7471
this.licenseState = licenseState;
7572
this.threadContext = threadContext;
7673
this.reservedRealm = reservedRealm;
74+
7775
assert XPackSettings.SECURITY_ENABLED.get(settings) : "security must be enabled";
7876
assert factories.get(ReservedRealm.TYPE) == null;
77+
7978
final List<RealmConfig> realmConfigs = buildRealmConfigs();
80-
this.realms = initRealms(realmConfigs);
81-
assert realms.get(0) == reservedRealm : "the first realm must be reserved realm";
82-
// pre-computing a list of internal only realms allows us to have much cheaper iteration than a custom iterator
83-
// and is also simpler in terms of logic. These lists are small, so the duplication should not be a real issue here
84-
List<Realm> standardRealms = new ArrayList<>(List.of(reservedRealm));
85-
List<Realm> basicRealms = new ArrayList<>(List.of(reservedRealm));
86-
for (Realm realm : realms) {
87-
// don't add the reserved realm here otherwise we end up with only this realm...
88-
if (InternalRealms.isStandardRealm(realm.type())) {
89-
standardRealms.add(realm);
90-
}
79+
this.allConfiguredRealms = initRealms(realmConfigs);
80+
this.allConfiguredRealms.forEach(r -> r.initialize(allConfiguredRealms, licenseState));
81+
assert allConfiguredRealms.get(0) == reservedRealm : "the first realm must be reserved realm";
9182

92-
if (InternalRealms.isBuiltinRealm(realm.type())) {
93-
basicRealms.add(realm);
94-
}
95-
}
83+
recomputeActiveRealms();
84+
licenseState.addListener(this::recomputeActiveRealms);
85+
}
9686

97-
if (Assertions.ENABLED) {
98-
for (List<Realm> realmList : Arrays.asList(standardRealms, basicRealms)) {
99-
assert realmList.get(0) == reservedRealm : "the first realm must be reserved realm";
100-
}
87+
protected void recomputeActiveRealms() {
88+
final XPackLicenseState licenseStateSnapshot = licenseState.copyCurrentLicenseState();
89+
final List<Realm> licensedRealms = calculateLicensedRealms(licenseStateSnapshot);
90+
logger.info(
91+
"license mode is [{}], currently licensed security realms are [{}]",
92+
licenseStateSnapshot.getOperationMode().description(),
93+
Strings.collectionToCommaDelimitedString(licensedRealms)
94+
);
95+
96+
// Stop license-tracking for any previously-active realms that are no longer allowed
97+
if (activeRealms != null) {
98+
activeRealms.stream().filter(r -> licensedRealms.contains(r) == false).forEach(realm -> {
99+
if (InternalRealms.isStandardRealm(realm.type())) {
100+
Security.STANDARD_REALMS_FEATURE.stopTracking(licenseStateSnapshot, realm.name());
101+
} else {
102+
Security.ALL_REALMS_FEATURE.stopTracking(licenseStateSnapshot, realm.name());
103+
}
104+
});
101105
}
102106

103-
this.standardRealmsOnly = Collections.unmodifiableList(standardRealms);
104-
this.nativeRealmsOnly = Collections.unmodifiableList(basicRealms);
105-
realms.forEach(r -> r.initialize(this, licenseState));
107+
activeRealms = licensedRealms;
106108
}
107109

108110
@Override
@@ -114,37 +116,33 @@ public Iterator<Realm> iterator() {
114116
* Returns a list of realms that are configured, but are not permitted under the current license.
115117
*/
116118
public List<Realm> getUnlicensedRealms() {
117-
final XPackLicenseState licenseStateSnapshot = licenseState.copyCurrentLicenseState();
118-
119-
// If all realms are allowed, then nothing is unlicensed
120-
if (Security.ALL_REALMS_FEATURE.checkWithoutTracking(licenseStateSnapshot)) {
121-
return Collections.emptyList();
122-
}
123-
124-
final List<Realm> allowedRealms = this.asList();
125-
// Shortcut for the typical case, all the configured realms are allowed
126-
if (allowedRealms.equals(this.realms)) {
119+
final List<Realm> activeSnapshot = activeRealms;
120+
if (activeSnapshot.equals(allConfiguredRealms)) {
127121
return Collections.emptyList();
128122
}
129123

130124
// Otherwise, we return anything in "all realms" that is not in the allowed realm list
131-
return realms.stream().filter(r -> allowedRealms.contains(r) == false).collect(Collectors.toUnmodifiableList());
125+
return allConfiguredRealms.stream().filter(r -> activeSnapshot.contains(r) == false).collect(Collectors.toUnmodifiableList());
132126
}
133127

134128
public Stream<Realm> stream() {
135129
return StreamSupport.stream(this.spliterator(), false);
136130
}
137131

138132
public List<Realm> asList() {
139-
// TODO : Recalculate this when the license changes rather than on every call
140-
return realms.stream().filter(r -> checkLicense(r, licenseState)).collect(Collectors.toUnmodifiableList());
133+
assert activeRealms != null : "Active realms not configured";
134+
return activeRealms;
135+
}
136+
137+
// Protected for testing
138+
protected List<Realm> calculateLicensedRealms(XPackLicenseState licenseStateSnapshot) {
139+
return allConfiguredRealms.stream()
140+
.filter(r -> checkLicense(r, licenseStateSnapshot))
141+
.collect(Collectors.toUnmodifiableList());
141142
}
142143

143144
private static boolean checkLicense(Realm realm, XPackLicenseState licenseState) {
144-
if (ReservedRealm.TYPE.equals(realm.type())) {
145-
return true;
146-
}
147-
if (InternalRealms.isBuiltinRealm(realm.type())) {
145+
if (isBasicLicensedRealm(realm)) {
148146
return true;
149147
}
150148
if (InternalRealms.isStandardRealm(realm.type())) {
@@ -153,8 +151,12 @@ private static boolean checkLicense(Realm realm, XPackLicenseState licenseState)
153151
return Security.ALL_REALMS_FEATURE.checkAndStartTracking(licenseState, realm.name());
154152
}
155153

154+
private static boolean isBasicLicensedRealm(Realm realm) {
155+
return ReservedRealm.TYPE.equals(realm.type()) || InternalRealms.isBuiltinRealm(realm.type());
156+
}
157+
156158
public Realm realm(String name) {
157-
for (Realm realm : realms) {
159+
for (Realm realm : activeRealms) {
158160
if (name.equals(realm.name())) {
159161
return realm;
160162
}

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.elasticsearch.index.get.GetResult;
5252
import org.elasticsearch.index.seqno.SequenceNumbers;
5353
import org.elasticsearch.index.shard.ShardId;
54+
import org.elasticsearch.license.License;
5455
import org.elasticsearch.license.MockLicenseState;
5556
import org.elasticsearch.license.XPackLicenseState;
5657
import org.elasticsearch.license.XPackLicenseState.Feature;
@@ -220,16 +221,23 @@ public void init() throws Exception {
220221
.build();
221222
MockLicenseState licenseState = mock(MockLicenseState.class);
222223
when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(true);
224+
when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
223225
when(licenseState.checkFeature(Feature.SECURITY_TOKEN_SERVICE)).thenReturn(true);
224226
when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState);
225227
when(licenseState.checkFeature(Feature.SECURITY_AUDITING)).thenReturn(true);
228+
when(licenseState.getOperationMode()).thenReturn(randomFrom(License.OperationMode.ENTERPRISE, License.OperationMode.PLATINUM));
229+
226230
ReservedRealm reservedRealm = mock(ReservedRealm.class);
227231
when(reservedRealm.type()).thenReturn("reserved");
228232
when(reservedRealm.name()).thenReturn("reserved_realm");
229233
realms = spy(new TestRealms(Settings.EMPTY, TestEnvironment.newEnvironment(settings),
230234
Map.of(FileRealmSettings.TYPE, config -> mock(FileRealm.class), NativeRealmSettings.TYPE, config -> mock(NativeRealm.class)),
231235
licenseState, threadContext, reservedRealm, Arrays.asList(firstRealm, secondRealm),
232-
Collections.singletonList(firstRealm)));
236+
Arrays.asList(firstRealm)));
237+
238+
// Needed because this is calculated in the constructor, which means the override doesn't get called correctly
239+
realms.recomputeActiveRealms();
240+
assertThat(realms.asList(), contains(firstRealm, secondRealm));
233241

234242
auditTrail = mock(AuditTrail.class);
235243
auditTrailService = new AuditTrailService(Collections.singletonList(auditTrail), licenseState);
@@ -388,7 +396,10 @@ public void testAuthenticateBothSupportSecondSucceeds() throws Exception {
388396
}, this::logAndFail));
389397
assertTrue(completed.get());
390398
verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
391-
verify(realms).asList();
399+
verify(realms, atLeastOnce()).recomputeActiveRealms();
400+
verify(realms, atLeastOnce()).calculateLicensedRealms(any(XPackLicenseState.class));
401+
verify(realms, atLeastOnce()).asList();
402+
// ^^ We don't care how many times these methods are called, we just check it here so that we can verify no more interactions below.
392403
verifyNoMoreInteractions(realms);
393404
}
394405

@@ -447,9 +458,8 @@ public void testAuthenticateSmartRealmOrdering() {
447458

448459
verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
449460
verify(firstRealm, times(2)).name(); // used above one time
450-
verify(firstRealm, atLeastOnce()).type();
451-
verify(secondRealm, Mockito.atLeast(3)).name(); // also used in license tracking
452-
verify(secondRealm, Mockito.atLeast(3)).type(); // used to create realm ref, and license tracking
461+
verify(secondRealm, Mockito.atLeast(2)).name(); // also used in license tracking
462+
verify(secondRealm, Mockito.atLeast(2)).type(); // used to create realm ref, and license tracking
453463
verify(firstRealm, times(2)).token(threadContext);
454464
verify(secondRealm, times(2)).token(threadContext);
455465
verify(firstRealm).supports(token);
@@ -573,9 +583,8 @@ public void testAuthenticateSmartRealmOrderingDisabled() {
573583
}, this::logAndFail));
574584
verify(auditTrail, times(2)).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
575585
verify(firstRealm, times(3)).name(); // used above one time
576-
verify(firstRealm, atLeastOnce()).type();
577-
verify(secondRealm, Mockito.atLeast(3)).name();
578-
verify(secondRealm, Mockito.atLeast(3)).type(); // used to create realm ref
586+
verify(secondRealm, Mockito.atLeast(2)).name();
587+
verify(secondRealm, Mockito.atLeast(2)).type(); // used to create realm ref
579588
verify(firstRealm, times(2)).token(threadContext);
580589
verify(secondRealm, times(2)).token(threadContext);
581590
verify(firstRealm, times(2)).supports(token);
@@ -638,10 +647,7 @@ public void testAuthenticateCached() throws Exception {
638647
assertThat(result.v1(), is(authentication));
639648
assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.REALM));
640649
verifyZeroInteractions(auditTrail);
641-
verify(firstRealm, atLeastOnce()).type();
642-
verify(secondRealm, atLeastOnce()).type();
643-
verify(secondRealm, atLeastOnce()).name(); // This realm is license-tracked, which uses the name
644-
verifyNoMoreInteractions(firstRealm, secondRealm);
650+
verifyZeroInteractions(firstRealm, secondRealm);
645651
verifyZeroInteractions(operatorPrivilegesService);
646652
}
647653

@@ -920,8 +926,6 @@ public void testAuthenticateTransportContextAndHeader() throws Exception {
920926
verifyZeroInteractions(operatorPrivilegesService);
921927
}, this::logAndFail));
922928
assertTrue(completed.compareAndSet(true, false));
923-
verify(firstRealm, atLeastOnce()).type();
924-
verify(firstRealm, atLeastOnce()).name();
925929
verifyNoMoreInteractions(firstRealm);
926930
reset(firstRealm);
927931
} finally {
@@ -970,8 +974,6 @@ public void testAuthenticateTransportContextAndHeader() throws Exception {
970974
verifyZeroInteractions(operatorPrivilegesService);
971975
}, this::logAndFail));
972976
assertTrue(completed.get());
973-
verify(firstRealm, atLeastOnce()).type();
974-
verify(firstRealm, atLeastOnce()).name();
975977
verifyNoMoreInteractions(firstRealm);
976978
} finally {
977979
terminate(threadPool2);
@@ -2106,12 +2108,40 @@ private static void mockRealmLookupReturnsNull(Realm realm, String username) {
21062108

21072109
static class TestRealms extends Realms {
21082110

2109-
TestRealms(Settings settings, Environment env, Map<String, Factory> factories, XPackLicenseState licenseState,
2110-
ThreadContext threadContext, ReservedRealm reservedRealm, List<Realm> realms, List<Realm> internalRealms)
2111-
throws Exception {
2111+
private final List<Realm> allRealms;
2112+
private final List<Realm> internalRealms;
2113+
2114+
TestRealms(
2115+
Settings settings,
2116+
Environment env,
2117+
Map<String, Factory> factories,
2118+
XPackLicenseState licenseState,
2119+
ThreadContext threadContext,
2120+
ReservedRealm reservedRealm,
2121+
List<Realm> realms,
2122+
List<Realm> internalRealms
2123+
) throws Exception {
21122124
super(settings, env, factories, licenseState, threadContext, reservedRealm);
2113-
this.realms = realms;
2114-
this.standardRealmsOnly = internalRealms;
2125+
this.allRealms = realms;
2126+
this.internalRealms = internalRealms;
2127+
}
2128+
2129+
@Override
2130+
protected List<Realm> calculateLicensedRealms(XPackLicenseState licenseState) {
2131+
if (allRealms == null) {
2132+
// This can happen because the realms are recalculated during construction
2133+
return super.calculateLicensedRealms(licenseState);
2134+
}
2135+
if (Security.STANDARD_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
2136+
return allRealms;
2137+
} else {
2138+
return internalRealms;
2139+
}
2140+
}
2141+
2142+
// Make public for testing
2143+
public void recomputeActiveRealms() {
2144+
super.recomputeActiveRealms();
21152145
}
21162146
}
21172147

0 commit comments

Comments
 (0)