Skip to content

Commit 71a42db

Browse files
[7.x] Rely on the computeIfAbsent logic to prevent duplicated compilation of scripts (elastic#55467) (elastic#58123)
Instead of serializing compilation using a plain lock / mutex combined with a double check, rely on the computeIfAbsent logic to prevent duplicated compilation of scripts. Made checkCompilationLimit to be thread-safe and lock free. Backport: 865acad Co-authored-by: Michael Bischoff <[email protected]>
1 parent e268a89 commit 71a42db

File tree

2 files changed

+70
-58
lines changed

2 files changed

+70
-58
lines changed

server/src/main/java/org/elasticsearch/script/ScriptCache.java

Lines changed: 66 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
import java.util.Map;
3434
import java.util.Objects;
35+
import java.util.concurrent.ExecutionException;
36+
import java.util.concurrent.atomic.AtomicReference;
3537

3638
/**
3739
* Script cache and compilation rate limiter.
@@ -44,12 +46,7 @@ public class ScriptCache {
4446

4547
private final Cache<CacheKey, Object> cache;
4648
private final ScriptMetrics scriptMetrics;
47-
48-
private final Object lock = new Object();
49-
50-
// Mutable fields, visible for tests
51-
long lastInlineCompileTime;
52-
double scriptsPerTimeWindow;
49+
final AtomicReference<TokenBucketState> tokenBucketState;
5350

5451
// Cache settings or derived from settings
5552
final int cacheSize;
@@ -81,11 +78,9 @@ public class ScriptCache {
8178
this.cache = cacheBuilder.removalListener(new ScriptCacheRemovalListener()).build();
8279

8380
this.rate = maxCompilationRate;
84-
this.scriptsPerTimeWindow = this.rate.v1();
8581
this.compilesAllowedPerNano = ((double) rate.v1()) / rate.v2().nanos();
86-
87-
this.lastInlineCompileTime = System.nanoTime();
8882
this.scriptMetrics = new ScriptMetrics();
83+
this.tokenBucketState = new AtomicReference<TokenBucketState>(new TokenBucketState(this.rate.v1()));
8984
}
9085

9186
<FactoryType> FactoryType compile(
@@ -98,47 +93,43 @@ <FactoryType> FactoryType compile(
9893
) {
9994
String lang = scriptEngine.getType();
10095
CacheKey cacheKey = new CacheKey(lang, idOrCode, context.name, options);
101-
Object compiledScript = cache.get(cacheKey);
102-
103-
if (compiledScript != null) {
104-
return context.factoryClazz.cast(compiledScript);
105-
}
10696

107-
// Synchronize so we don't compile scripts many times during multiple shards all compiling a script
108-
synchronized (lock) {
109-
// Retrieve it again in case it has been put by a different thread
110-
compiledScript = cache.get(cacheKey);
111-
112-
if (compiledScript == null) {
113-
try {
114-
// Either an un-cached inline script or indexed script
115-
// If the script type is inline the name will be the same as the code for identification in exceptions
116-
// but give the script engine the chance to be better, give it separate name + source code
117-
// for the inline case, then its anonymous: null.
118-
if (logger.isTraceEnabled()) {
119-
logger.trace("context [{}]: compiling script, type: [{}], lang: [{}], options: [{}]", context.name, type,
120-
lang, options);
121-
}
122-
// Check whether too many compilations have happened
123-
checkCompilationLimit();
124-
compiledScript = scriptEngine.compile(id, idOrCode, context, options);
125-
} catch (ScriptException good) {
126-
// TODO: remove this try-catch completely, when all script engines have good exceptions!
127-
throw good; // its already good
128-
} catch (Exception exception) {
129-
throw new GeneralScriptException("Failed to compile " + type + " script [" + id + "] using lang [" + lang + "]",
130-
exception);
97+
// Relying on computeIfAbsent to avoid multiple threads from compiling the same script
98+
try {
99+
return context.factoryClazz.cast(cache.computeIfAbsent(cacheKey, key -> {
100+
// Either an un-cached inline script or indexed script
101+
// If the script type is inline the name will be the same as the code for identification in exceptions
102+
// but give the script engine the chance to be better, give it separate name + source code
103+
// for the inline case, then its anonymous: null.
104+
if (logger.isTraceEnabled()) {
105+
logger.trace("context [{}]: compiling script, type: [{}], lang: [{}], options: [{}]", context.name, type,
106+
lang, options);
131107
}
132-
108+
// Check whether too many compilations have happened
109+
checkCompilationLimit();
110+
Object compiledScript = scriptEngine.compile(id, idOrCode, context, options);
133111
// Since the cache key is the script content itself we don't need to
134112
// invalidate/check the cache if an indexed script changes.
135113
scriptMetrics.onCompilation();
136-
cache.put(cacheKey, compiledScript);
114+
return compiledScript;
115+
}));
116+
} catch (ExecutionException executionException) {
117+
Throwable cause = executionException.getCause();
118+
if (cause instanceof ScriptException) {
119+
throw (ScriptException) cause;
120+
} else if (cause instanceof Exception) {
121+
throw new GeneralScriptException("Failed to compile " + type + " script [" + id + "] using lang [" + lang + "]", cause);
122+
} else {
123+
rethrow(cause);
124+
throw new AssertionError(cause);
137125
}
138-
139126
}
127+
}
140128

141-
return context.factoryClazz.cast(compiledScript);
129+
/** Hack to rethrow unknown Exceptions from compile: */
130+
@SuppressWarnings("unchecked")
131+
static <T extends Throwable> void rethrow(Throwable t) throws T {
132+
throw (T) t;
142133
}
143134

144135
public ScriptStats stats() {
@@ -159,21 +150,26 @@ void checkCompilationLimit() {
159150
return;
160151
}
161152

162-
long now = System.nanoTime();
163-
long timePassed = now - lastInlineCompileTime;
164-
lastInlineCompileTime = now;
153+
TokenBucketState tokenBucketState = this.tokenBucketState.updateAndGet(current -> {
154+
long now = System.nanoTime();
155+
long timePassed = now - current.lastInlineCompileTime;
156+
double scriptsPerTimeWindow = current.availableTokens + (timePassed) * compilesAllowedPerNano;
165157

166-
scriptsPerTimeWindow += (timePassed) * compilesAllowedPerNano;
158+
// It's been over the time limit anyway, readjust the bucket to be level
159+
if (scriptsPerTimeWindow > rate.v1()) {
160+
scriptsPerTimeWindow = rate.v1();
161+
}
167162

168-
// It's been over the time limit anyway, readjust the bucket to be level
169-
if (scriptsPerTimeWindow > rate.v1()) {
170-
scriptsPerTimeWindow = rate.v1();
171-
}
163+
// If there is enough tokens in the bucket, allow the request and decrease the tokens by 1
164+
if (scriptsPerTimeWindow >= 1) {
165+
scriptsPerTimeWindow -= 1.0;
166+
return new TokenBucketState(now, scriptsPerTimeWindow, true);
167+
} else {
168+
return new TokenBucketState(now, scriptsPerTimeWindow, false);
169+
}
170+
});
172171

173-
// If there is enough tokens in the bucket, allow the request and decrease the tokens by 1
174-
if (scriptsPerTimeWindow >= 1) {
175-
scriptsPerTimeWindow -= 1.0;
176-
} else {
172+
if(!tokenBucketState.tokenSuccessfullyTaken) {
177173
scriptMetrics.onCompilationLimit();
178174
// Otherwise reject the request
179175
throw new CircuitBreakingException("[script] Too many dynamic script compilations within, max: [" +
@@ -231,4 +227,20 @@ public int hashCode() {
231227
return Objects.hash(lang, idOrCode, context, options);
232228
}
233229
}
230+
231+
static class TokenBucketState {
232+
public final long lastInlineCompileTime;
233+
public final double availableTokens;
234+
public final boolean tokenSuccessfullyTaken;
235+
236+
private TokenBucketState(double availableTokens) {
237+
this(System.nanoTime(), availableTokens, false);
238+
}
239+
240+
private TokenBucketState(long lastInlineCompileTime, double availableTokens, boolean tokenSuccessfullyTaken) {
241+
this.lastInlineCompileTime = lastInlineCompileTime;
242+
this.availableTokens = availableTokens;
243+
this.tokenSuccessfullyTaken = tokenSuccessfullyTaken;
244+
}
245+
}
234246
}

server/src/test/java/org/elasticsearch/script/ScriptCacheTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ public void testUnlimitedCompilationRate() {
5959
final TimeValue expire = ScriptService.SCRIPT_GENERAL_CACHE_EXPIRE_SETTING.get(Settings.EMPTY);
6060
String settingName = ScriptService.SCRIPT_GENERAL_MAX_COMPILATIONS_RATE_SETTING.getKey();
6161
ScriptCache cache = new ScriptCache(size, expire, ScriptCache.UNLIMITED_COMPILATION_RATE, settingName);
62-
long lastInlineCompileTime = cache.lastInlineCompileTime;
63-
double scriptsPerTimeWindow = cache.scriptsPerTimeWindow;
62+
ScriptCache.TokenBucketState initialState = cache.tokenBucketState.get();
6463
for(int i=0; i < 3000; i++) {
6564
cache.checkCompilationLimit();
66-
assertEquals(lastInlineCompileTime, cache.lastInlineCompileTime);
67-
assertEquals(scriptsPerTimeWindow, cache.scriptsPerTimeWindow, 0.0); // delta of 0.0 because it should never change
65+
ScriptCache.TokenBucketState currentState = cache.tokenBucketState.get();
66+
assertEquals(initialState.lastInlineCompileTime, currentState.lastInlineCompileTime);
67+
assertEquals(initialState.availableTokens, currentState.availableTokens, 0.0); // delta of 0.0 because it should never change
6868
}
6969
}
7070
}

0 commit comments

Comments
 (0)