diff --git a/docs/changelog/88209.yaml b/docs/changelog/88209.yaml
new file mode 100644
index 0000000000000..0e16b6553fbfd
--- /dev/null
+++ b/docs/changelog/88209.yaml
@@ -0,0 +1,6 @@
+pr: 88209
+summary: Prioritize shard snapshot tasks over file snapshot tasks and limit the number of concurrently running snapshot tasks
+area: Snapshot/Restore
+type: enhancement
+issues:
+ - 83408
diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java
new file mode 100644
index 0000000000000..0f67a45bd0c96
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.common.util.concurrent;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Executor;
+import java.util.concurrent.PriorityBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * {@link PrioritizedThrottledTaskRunner} performs the enqueued tasks in the order dictated by the
+ * natural ordering of the tasks, limiting the max number of concurrently running tasks. Each new task
+ * that is dequeued to be run is forked off to the given executor.
+ */
+public class PrioritizedThrottledTaskRunner<T extends Comparable<T> & Runnable> {
+    private static final Logger logger = LogManager.getLogger(PrioritizedThrottledTaskRunner.class);
+
+    private final String taskRunnerName;
+    // The max number of tasks that this runner will schedule to concurrently run on the executor.
+    private final int maxRunningTasks;
+    // As we fork off dequeued tasks to the given executor, technically the following counter represents
+    // the number of the concurrent pollAndSpawn calls currently checking the queue for a task to run. This
+    // doesn't necessarily correspond to currently running tasks, since a pollAndSpawn could return without
+    // actually running a task when the queue is empty.
+    private final AtomicInteger runningTasks = new AtomicInteger();
+    private final BlockingQueue<T> tasks = new PriorityBlockingQueue<>();
+    private final Executor executor;
+
+    public PrioritizedThrottledTaskRunner(final String name, final int maxRunningTasks, final Executor executor) {
+        assert maxRunningTasks > 0;
+        this.taskRunnerName = name;
+        this.maxRunningTasks = maxRunningTasks;
+        this.executor = executor;
+    }
+
+    public void enqueueTask(final T task) {
+        logger.trace("[{}] enqueuing task {}", taskRunnerName, task);
+        tasks.add(task);
+        // Try to run a task since now there is at least one in the queue. If the maxRunningTasks is
+        // reached, the task is just enqueued.
+        pollAndSpawn();
+    }
+
+    // visible for testing
+    protected void pollAndSpawn() {
+        // A pollAndSpawn attempts to run a new task. There could be many concurrent pollAndSpawn calls competing
+        // to get a "free slot", since we attempt to run a new task on every enqueueTask call and every time an
+        // existing task is finished.
+ while (incrementRunningTasks()) { + T task = tasks.poll(); + if (task == null) { + logger.trace("[{}] task queue is empty", taskRunnerName); + // We have taken up a "free slot", but there are no tasks in the queue! This could happen each time a worker + // sees an empty queue after running a task. Decrement to give competing pollAndSpawn calls a chance! + int decremented = runningTasks.decrementAndGet(); + assert decremented >= 0; + // We might have blocked all competing pollAndSpawn calls. This could happen for example when + // maxRunningTasks=1 and a task got enqueued just after checking the queue but before decrementing. + // To be sure, return only if the queue is still empty. If the queue is not empty, this might be the + // only pollAndSpawn call in progress, and returning without peeking would risk ending up with a + // non-empty queue and no workers! + if (tasks.peek() == null) break; + } else { + executor.execute(() -> runTask(task)); + } + } + } + + // Each worker thread that runs a task, first needs to get a "free slot" in order to respect maxRunningTasks. + private boolean incrementRunningTasks() { + int preUpdateValue = runningTasks.getAndUpdate(v -> v < maxRunningTasks ? v + 1 : v); + assert preUpdateValue <= maxRunningTasks; + return preUpdateValue < maxRunningTasks; + } + + // Only use for testing + public int runningTasks() { + return runningTasks.get(); + } + + // Only use for testing + public int queueSize() { + return tasks.size(); + } + + private void runTask(final T task) { + try { + logger.trace("[{}] running task {}", taskRunnerName, task); + task.run(); + } finally { + // To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running + // a task is finished. + int decremented = runningTasks.decrementAndGet(); + assert decremented >= 0; + pollAndSpawn(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/repositories/SnapshotShardContext.java b/server/src/main/java/org/elasticsearch/repositories/SnapshotShardContext.java index 2e022d440f753..68cef06071519 100644 --- a/server/src/main/java/org/elasticsearch/repositories/SnapshotShardContext.java +++ b/server/src/main/java/org/elasticsearch/repositories/SnapshotShardContext.java @@ -37,6 +37,7 @@ public final class SnapshotShardContext extends ActionListener.Delegating userMetadata; + private final long snapshotStartTime; /** * @param store store to be snapshotted @@ -51,6 +52,8 @@ public final class SnapshotShardContext extends ActionListener.Delegating userMetadata, + final long snapshotStartTime, ActionListener listener ) { super(ActionListener.runBefore(listener, commitRef::close)); @@ -75,6 +79,7 @@ public SnapshotShardContext( this.snapshotStatus = snapshotStatus; this.repositoryMetaVersion = repositoryMetaVersion; this.userMetadata = userMetadata; + this.snapshotStartTime = snapshotStartTime; } public Store store() { @@ -114,6 +119,10 @@ public Map userMetadata() { return userMetadata; } + public long snapshotStartTime() { + return snapshotStartTime; + } + @Override public void onResponse(ShardSnapshotResult result) { delegate.onResponse(result); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 35d72e13b371a..5db1d0961364a 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ 
-150,6 +150,7 @@ import java.util.stream.Stream; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo; import static org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo.canonicalName; /** @@ -376,6 +377,8 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp */ private final int maxSnapshotCount; + private final ShardSnapshotTaskRunner shardSnapshotTaskRunner; + /** * Constructs new BlobStoreRepository * @param metadata The metadata for this repository including name and settings @@ -405,6 +408,12 @@ protected BlobStoreRepository( this.basePath = basePath; this.maxSnapshotCount = MAX_SNAPSHOTS_SETTING.get(metadata.settings()); this.repoDataDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext()); + shardSnapshotTaskRunner = new ShardSnapshotTaskRunner( + threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), + threadPool.executor(ThreadPool.Names.SNAPSHOT), + this::doSnapshotShard, + this::snapshotFile + ); } @Override @@ -2629,6 +2638,10 @@ private void writeAtomic( @Override public void snapshotShard(SnapshotShardContext context) { + shardSnapshotTaskRunner.enqueueShardSnapshot(context); + } + + private void doSnapshotShard(SnapshotShardContext context) { if (isReadOnly()) { context.onFailure(new RepositoryException(metadata.name(), "cannot snapshot shard on a readonly repository")); return; @@ -2889,45 +2902,19 @@ public void snapshotShard(SnapshotShardContext context) { snapshotStatus.moveToDone(threadPool.absoluteTimeInMillis(), shardSnapshotResult); context.onResponse(shardSnapshotResult); }, context::onFailure); - if (indexIncrementalFileCount == 0) { + if (indexIncrementalFileCount == 0 || filesToSnapshot.isEmpty()) { allFilesUploadedListener.onResponse(Collections.emptyList()); return; } - final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT); - // Start as many workers as fit into the snapshot pool at once at the most - final int workers = Math.min(threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), indexIncrementalFileCount); - final ActionListener filesListener = fileQueueListener(filesToSnapshot, workers, allFilesUploadedListener); - for (int i = 0; i < workers; ++i) { - executeOneFileSnapshot(store, snapshotId, context.indexId(), snapshotStatus, filesToSnapshot, executor, filesListener); + final ActionListener filesListener = fileQueueListener(filesToSnapshot, filesToSnapshot.size(), allFilesUploadedListener); + for (FileInfo fileInfo : filesToSnapshot) { + shardSnapshotTaskRunner.enqueueFileSnapshot(context, fileInfo, filesListener); } } catch (Exception e) { context.onFailure(e); } } - private void executeOneFileSnapshot( - Store store, - SnapshotId snapshotId, - IndexId indexId, - IndexShardSnapshotStatus snapshotStatus, - BlockingQueue filesToSnapshot, - Executor executor, - ActionListener listener - ) throws InterruptedException { - final ShardId shardId = store.shardId(); - final BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo = filesToSnapshot.poll(0L, TimeUnit.MILLISECONDS); - if (snapshotFileInfo == null) { - listener.onResponse(null); - } else { - executor.execute(ActionRunnable.wrap(listener, l -> { - try (Releasable ignored = incrementStoreRef(store, snapshotStatus, shardId)) { - snapshotFile(snapshotFileInfo, indexId, shardId, snapshotId, snapshotStatus, store); - executeOneFileSnapshot(store, snapshotId, indexId, snapshotStatus, filesToSnapshot, executor, l); - } - })); - 
} - } - private static Releasable incrementStoreRef(Store store, IndexShardSnapshotStatus snapshotStatus, ShardId shardId) { if (store.tryIncRef() == false) { if (snapshotStatus.isAborted()) { @@ -3116,10 +3103,10 @@ void ensureNotClosing(final Store store) throws AlreadyClosedException { private static ActionListener fileQueueListener( BlockingQueue files, - int workers, + int numberOfFiles, ActionListener> listener ) { - return new GroupedActionListener<>(listener, workers).delegateResponse((l, e) -> { + return new GroupedActionListener<>(listener, numberOfFiles).delegateResponse((l, e) -> { files.clear(); // Stop uploading the remaining files if we run into any exception l.onFailure(e); }); @@ -3426,19 +3413,20 @@ private Tuple buildBlobStoreIndexShardSnapsh /** * Snapshot individual file - * @param fileInfo file to be snapshotted + * @param fileInfo file to snapshot */ - private void snapshotFile( - BlobStoreIndexShardSnapshot.FileInfo fileInfo, - IndexId indexId, - ShardId shardId, - SnapshotId snapshotId, - IndexShardSnapshotStatus snapshotStatus, - Store store - ) throws IOException { + private void snapshotFile(SnapshotShardContext context, FileInfo fileInfo) throws IOException { + final IndexId indexId = context.indexId(); + final Store store = context.store(); + final ShardId shardId = store.shardId(); + final IndexShardSnapshotStatus snapshotStatus = context.status(); + final SnapshotId snapshotId = context.snapshotId(); final BlobContainer shardContainer = shardContainer(indexId, shardId); final String file = fileInfo.physicalName(); - try (IndexInput indexInput = store.openVerifyingInput(file, IOContext.READONCE, fileInfo.metadata())) { + try ( + Releasable ignored = BlobStoreRepository.incrementStoreRef(store, snapshotStatus, store.shardId()); + IndexInput indexInput = store.openVerifyingInput(file, IOContext.READONCE, fileInfo.metadata()) + ) { for (int i = 0; i < fileInfo.numberOfParts(); i++) { final long partBytes = fileInfo.partBytes(i); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java new file mode 100644 index 0000000000000..41e6c8ce99471 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.repositories.blobstore; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.util.concurrent.PrioritizedThrottledTaskRunner; +import org.elasticsearch.repositories.SnapshotShardContext; + +import java.io.IOException; +import java.util.concurrent.Executor; +import java.util.function.Consumer; + +import static org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo; + +/** + * {@link ShardSnapshotTaskRunner} performs snapshotting tasks, prioritizing {@link ShardSnapshotTask} + * over {@link FileSnapshotTask}. Each enqueued shard to snapshot results in one {@link ShardSnapshotTask} + * and zero or more {@link FileSnapshotTask}s. 
+ */ +public class ShardSnapshotTaskRunner { + private final PrioritizedThrottledTaskRunner taskRunner; + private final Consumer shardSnapshotter; + private final CheckedBiConsumer fileSnapshotter; + + abstract static class SnapshotTask implements Comparable, Runnable { + protected final SnapshotShardContext context; + + SnapshotTask(SnapshotShardContext context) { + this.context = context; + } + + public abstract int priority(); + + public SnapshotShardContext context() { + return context; + } + + @Override + public final int compareTo(SnapshotTask other) { + int res = Integer.compare(priority(), other.priority()); + if (res != 0) { + return res; + } + return Long.compare(context.snapshotStartTime(), other.context.snapshotStartTime()); + } + } + + class ShardSnapshotTask extends SnapshotTask { + ShardSnapshotTask(SnapshotShardContext context) { + super(context); + } + + @Override + public void run() { + shardSnapshotter.accept(context); + } + + @Override + public int priority() { + return 1; + } + + @Override + public String toString() { + return getClass().getSimpleName() + "{snapshotID=[" + context.snapshotId() + "], indexID=[" + context.indexId() + "]}"; + } + } + + class FileSnapshotTask extends SnapshotTask { + private final FileInfo fileInfo; + private final ActionListener fileSnapshotListener; + + FileSnapshotTask(SnapshotShardContext context, FileInfo fileInfo, ActionListener fileSnapshotListener) { + super(context); + this.fileInfo = fileInfo; + this.fileSnapshotListener = fileSnapshotListener; + } + + @Override + public void run() { + ActionRunnable.run(fileSnapshotListener, () -> fileSnapshotter.accept(context, fileInfo)).run(); + } + + @Override + public int priority() { + return 2; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + "{snapshotID=[" + + context.snapshotId() + + "], indexID=[" + + context.indexId() + + "], file=[" + + fileInfo.name() + + "]}"; + } + } + + public ShardSnapshotTaskRunner( + final int maxRunningTasks, + final Executor executor, + final Consumer shardSnapshotter, + final CheckedBiConsumer fileSnapshotter + ) { + this.taskRunner = new PrioritizedThrottledTaskRunner<>("ShardSnapshotTaskRunner", maxRunningTasks, executor); + this.shardSnapshotter = shardSnapshotter; + this.fileSnapshotter = fileSnapshotter; + } + + public void enqueueShardSnapshot(final SnapshotShardContext context) { + ShardSnapshotTask task = new ShardSnapshotTask(context); + taskRunner.enqueueTask(task); + } + + public void enqueueFileSnapshot(final SnapshotShardContext context, final FileInfo fileInfo, final ActionListener listener) { + final FileSnapshotTask task = new FileSnapshotTask(context, fileInfo, listener); + taskRunner.enqueueTask(task); + } + + // visible for testing + int runningTasks() { + return taskRunner.runningTasks(); + } + + // visible for testing + int queueSize() { + return taskRunner.queueSize(); + } +} diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 09bce137ab8af..7e98eaa2c7119 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -239,7 +239,15 @@ private void startNewSnapshots(List snapshotsInProgre + snapshotStatus.generation() + "] for snapshot with old-format compatibility"; shardSnapshotTasks.add( - newShardSnapshotTask(shardId, snapshot, indexId, entry.userMetadata(), 
snapshotStatus, entry.version()) + newShardSnapshotTask( + shardId, + snapshot, + indexId, + entry.userMetadata(), + snapshotStatus, + entry.version(), + entry.startTime() + ) ); } @@ -272,43 +280,53 @@ private Runnable newShardSnapshotTask( final IndexId indexId, final Map userMetadata, final IndexShardSnapshotStatus snapshotStatus, - final Version entryVersion + final Version entryVersion, + final long entryStartTime ) { // separate method to make sure this lambda doesn't capture any heavy local objects like a SnapshotsInProgress.Entry - return () -> snapshot(shardId, snapshot, indexId, userMetadata, snapshotStatus, entryVersion, new ActionListener<>() { - @Override - public void onResponse(ShardSnapshotResult shardSnapshotResult) { - final ShardGeneration newGeneration = shardSnapshotResult.getGeneration(); - assert newGeneration != null; - assert newGeneration.equals(snapshotStatus.generation()); - if (logger.isDebugEnabled()) { - final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.asCopy(); - logger.debug( - "[{}][{}] completed snapshot to [{}] with status [{}] at generation [{}]", - shardId, - snapshot, - snapshot.getRepository(), - lastSnapshotStatus, - snapshotStatus.generation() - ); + return () -> snapshot( + shardId, + snapshot, + indexId, + userMetadata, + snapshotStatus, + entryVersion, + entryStartTime, + new ActionListener<>() { + @Override + public void onResponse(ShardSnapshotResult shardSnapshotResult) { + final ShardGeneration newGeneration = shardSnapshotResult.getGeneration(); + assert newGeneration != null; + assert newGeneration.equals(snapshotStatus.generation()); + if (logger.isDebugEnabled()) { + final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.asCopy(); + logger.debug( + "[{}][{}] completed snapshot to [{}] with status [{}] at generation [{}]", + shardId, + snapshot, + snapshot.getRepository(), + lastSnapshotStatus, + snapshotStatus.generation() + ); + } + notifySuccessfulSnapshotShard(snapshot, shardId, shardSnapshotResult); } - notifySuccessfulSnapshotShard(snapshot, shardId, shardSnapshotResult); - } - @Override - public void onFailure(Exception e) { - final String failure; - if (e instanceof AbortedSnapshotException) { - failure = "aborted"; - logger.debug(() -> format("[%s][%s] aborted shard snapshot", shardId, snapshot), e); - } else { - failure = summarizeFailure(e); - logger.warn(() -> format("[%s][%s] failed to snapshot shard", shardId, snapshot), e); + @Override + public void onFailure(Exception e) { + final String failure; + if (e instanceof AbortedSnapshotException) { + failure = "aborted"; + logger.debug(() -> format("[%s][%s] aborted shard snapshot", shardId, snapshot), e); + } else { + failure = summarizeFailure(e); + logger.warn(() -> format("[%s][%s] failed to snapshot shard", shardId, snapshot), e); + } + snapshotStatus.moveToFailed(threadPool.absoluteTimeInMillis(), failure); + notifyFailedSnapshotShard(snapshot, shardId, failure, snapshotStatus.generation()); } - snapshotStatus.moveToFailed(threadPool.absoluteTimeInMillis(), failure); - notifyFailedSnapshotShard(snapshot, shardId, failure, snapshotStatus.generation()); } - }); + ); } // package private for testing @@ -346,6 +364,7 @@ private void snapshot( final Map userMetadata, final IndexShardSnapshotStatus snapshotStatus, Version version, + final long entryStartTime, ActionListener listener ) { try { @@ -382,6 +401,7 @@ private void snapshot( snapshotStatus, version, userMetadata, + entryStartTime, listener ) ); diff --git 
a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java new file mode 100644 index 0000000000000..3768f5556b4e9 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java @@ -0,0 +1,177 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class PrioritizedThrottledTaskRunnerTests extends ESTestCase { + private ThreadPool threadPool; + private Executor executor; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool("test"); + executor = threadPool.executor(ThreadPool.Names.GENERIC); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + static class TestTask implements Comparable, Runnable { + + private final Runnable runnable; + private final int priority; + + TestTask(Runnable runnable, int priority) { + this.runnable = runnable; + this.priority = priority; + } + + @Override + public int compareTo(TestTask o) { + return priority - o.priority; + } + + @Override + public void run() { + runnable.run(); + } + } + + public void testMultiThreadedEnqueue() throws Exception { + final int maxTasks = randomIntBetween(1, threadPool.info(ThreadPool.Names.GENERIC).getMax()); + PrioritizedThrottledTaskRunner taskRunner = new PrioritizedThrottledTaskRunner<>("test", maxTasks, executor); + final int enqueued = randomIntBetween(2 * maxTasks, 10 * maxTasks); + AtomicInteger executed = new AtomicInteger(); + CountDownLatch threadBlocker = new CountDownLatch(enqueued); + for (int i = 0; i < enqueued; i++) { + new Thread(() -> { + try { + threadBlocker.countDown(); + threadBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + taskRunner.enqueueTask(new TestTask(() -> { + try { + Thread.sleep(randomLongBetween(0, 10)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + executed.incrementAndGet(); + }, getRandomPriority())); + assertThat(taskRunner.runningTasks(), lessThanOrEqualTo(maxTasks)); + }).start(); + } + // Eventually all tasks are executed + assertBusy(() -> { + assertThat(executed.get(), equalTo(enqueued)); + assertThat(taskRunner.runningTasks(), equalTo(0)); + }); + assertThat(taskRunner.queueSize(), equalTo(0)); + } + + public void testTasksRunInOrder() throws Exception { + final int n = randomIntBetween(1, threadPool.info(ThreadPool.Names.GENERIC).getMax()); + final int enqueued = randomIntBetween(2 * n, 10 
* n); + // To check that tasks are run in the order (based on their priority), limit max running tasks to 1 + // and wait until all tasks are enqueued. + CountDownLatch workerBlocker = new CountDownLatch(1); + PrioritizedThrottledTaskRunner taskRunner = new PrioritizedThrottledTaskRunner<>("test", 1, executor) { + @Override + protected void pollAndSpawn() { + try { + workerBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + super.pollAndSpawn(); + } + }; + List taskPriorities = new ArrayList<>(enqueued); + List executedPriorities = new ArrayList<>(enqueued); + for (int i = 0; i < enqueued; i++) { + final int priority = getRandomPriority(); + taskPriorities.add(priority); + new Thread(() -> taskRunner.enqueueTask(new TestTask(() -> executedPriorities.add(priority), priority))).start(); + } + assertBusy(() -> assertThat(taskRunner.queueSize(), equalTo(enqueued))); + assertThat(taskRunner.runningTasks(), equalTo(0)); + workerBlocker.countDown(); + // Eventually all tasks are executed + assertBusy(() -> { + assertThat(executedPriorities.size(), equalTo(enqueued)); + assertThat(taskRunner.runningTasks(), equalTo(0)); + }); + assertThat(taskRunner.queueSize(), equalTo(0)); + Collections.sort(taskPriorities); + assertThat(executedPriorities, equalTo(taskPriorities)); + } + + public void testEnqueueSpawnsNewTasksUpToMax() throws Exception { + int maxTasks = randomIntBetween(1, threadPool.info(ThreadPool.Names.GENERIC).getMax()); + CountDownLatch taskBlocker = new CountDownLatch(1); + AtomicInteger executed = new AtomicInteger(); + PrioritizedThrottledTaskRunner taskRunner = new PrioritizedThrottledTaskRunner<>("test", maxTasks, executor); + final int enqueued = maxTasks - 1; // So that it is possible to run at least one more task + for (int i = 0; i < enqueued; i++) { + taskRunner.enqueueTask(new TestTask(() -> { + try { + taskBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + executed.incrementAndGet(); + }, getRandomPriority())); + assertThat(taskRunner.runningTasks(), equalTo(i + 1)); + } + // Enqueueing one or more new tasks would create only one new running task + final int newTasks = randomIntBetween(1, 10); + for (int i = 0; i < newTasks; i++) { + taskRunner.enqueueTask(new TestTask(() -> { + try { + taskBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + executed.incrementAndGet(); + }, getRandomPriority())); + assertThat(taskRunner.runningTasks(), equalTo(maxTasks)); + } + assertThat(taskRunner.queueSize(), equalTo(newTasks - 1)); + taskBlocker.countDown(); + /// Eventually all tasks are executed + assertBusy(() -> { + assertThat(executed.get(), equalTo(enqueued + newTasks)); + assertThat(taskRunner.runningTasks(), equalTo(0)); + }); + assertThat(taskRunner.queueSize(), equalTo(0)); + } + + private int getRandomPriority() { + return randomIntBetween(-1000, 1000); + } +} diff --git a/server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java b/server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java index 8b2339d81c7da..87480f84487ae 100644 --- a/server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java @@ -53,7 +53,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static 
org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; @@ -69,6 +68,9 @@ public void setUp() throws Exception { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); ThreadPool threadPool = mock(ThreadPool.class); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.info(ThreadPool.Names.SNAPSHOT)).thenReturn( + new ThreadPool.Info(ThreadPool.Names.SNAPSHOT, ThreadPool.ThreadPoolType.FIXED, randomIntBetween(1, 10)) + ); final TransportService transportService = new TransportService( Settings.EMPTY, mock(Transport.class), diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java new file mode 100644 index 0000000000000..e07c6d4966a3e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.repositories.blobstore; + +import org.apache.lucene.store.ByteBuffersDirectory; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.engine.Engine; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; +import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot; +import org.elasticsearch.index.store.Store; +import org.elasticsearch.index.store.StoreFileMetadata; +import org.elasticsearch.repositories.IndexId; +import org.elasticsearch.repositories.SnapshotShardContext; +import org.elasticsearch.snapshots.SnapshotId; +import org.elasticsearch.test.DummyShardLock; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.Collections; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThan; + +public class ShardSnapshotTaskRunnerTests extends ESTestCase { + + private ThreadPool threadPool; + private Executor executor; + + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool("test"); + executor = threadPool.executor(ThreadPool.Names.SNAPSHOT); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + private static class MockedRepo { + private final AtomicInteger expectedFileSnapshotTasks = new AtomicInteger(); + private final AtomicInteger finishedFileSnapshotTasks = new AtomicInteger(); + private final AtomicInteger finishedShardSnapshotTasks = new AtomicInteger(); + private final AtomicInteger finishedShardSnapshots = new AtomicInteger(); + 
private ShardSnapshotTaskRunner taskRunner; + + public void setTaskRunner(ShardSnapshotTaskRunner taskRunner) { + this.taskRunner = taskRunner; + } + + public void snapshotShard(SnapshotShardContext context) { + int filesToUpload = randomIntBetween(0, 10); + if (filesToUpload == 0) { + finishedShardSnapshots.incrementAndGet(); + } else { + expectedFileSnapshotTasks.addAndGet(filesToUpload); + ActionListener uploadListener = new GroupedActionListener<>( + ActionListener.wrap(finishedShardSnapshots::incrementAndGet), + filesToUpload + ); + for (int i = 0; i < filesToUpload; i++) { + taskRunner.enqueueFileSnapshot(context, dummyFileInfo(), uploadListener); + } + } + finishedShardSnapshotTasks.incrementAndGet(); + } + + public void snapshotFile(SnapshotShardContext context, BlobStoreIndexShardSnapshot.FileInfo fileInfo) { + finishedFileSnapshotTasks.incrementAndGet(); + } + + public int expectedFileSnapshotTasks() { + return expectedFileSnapshotTasks.get(); + } + + public int finishedFileSnapshotTasks() { + return finishedFileSnapshotTasks.get(); + } + + public int finishedShardSnapshots() { + return finishedShardSnapshots.get(); + } + + public int finishedShardSnapshotTasks() { + return finishedShardSnapshotTasks.get(); + } + } + + private static BlobStoreIndexShardSnapshot.FileInfo dummyFileInfo() { + String filename = randomAlphaOfLength(10); + StoreFileMetadata metadata = new StoreFileMetadata(filename, 10, "CHECKSUM", Version.CURRENT.luceneVersion.toString()); + return new BlobStoreIndexShardSnapshot.FileInfo(filename, metadata, null); + } + + private SnapshotShardContext dummyContext() { + return dummyContext(new SnapshotId(randomAlphaOfLength(10), UUIDs.randomBase64UUID()), randomMillisUpToYear9999()); + } + + private SnapshotShardContext dummyContext(final SnapshotId snapshotId, final long startTime) { + IndexId indexId = new IndexId(randomAlphaOfLength(10), UUIDs.randomBase64UUID()); + ShardId shardId = new ShardId(indexId.getName(), indexId.getId(), 1); + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .build(); + IndexSettings indexSettings = new IndexSettings( + IndexMetadata.builder(indexId.getName()).settings(settings).build(), + Settings.EMPTY + ); + Store dummyStore = new Store(shardId, indexSettings, new ByteBuffersDirectory(), new DummyShardLock(shardId)); + return new SnapshotShardContext( + dummyStore, + null, + snapshotId, + indexId, + new Engine.IndexCommitRef(null, () -> {}), + null, + IndexShardSnapshotStatus.newInitializing(null), + Version.CURRENT, + Collections.emptyMap(), + startTime, + ActionListener.noop() + ); + } + + public void testShardSnapshotTaskRunner() throws Exception { + int maxTasks = randomIntBetween(1, threadPool.info(ThreadPool.Names.SNAPSHOT).getMax()); + MockedRepo repo = new MockedRepo(); + ShardSnapshotTaskRunner taskRunner = new ShardSnapshotTaskRunner(maxTasks, executor, repo::snapshotShard, repo::snapshotFile); + repo.setTaskRunner(taskRunner); + int enqueuedSnapshots = randomIntBetween(maxTasks * 2, maxTasks * 10); + for (int i = 0; i < enqueuedSnapshots; i++) { + threadPool.generic().execute(() -> taskRunner.enqueueShardSnapshot(dummyContext())); + } + // Eventually all snapshots are finished + assertBusy(() -> { + assertThat(repo.finishedShardSnapshots(), equalTo(enqueuedSnapshots)); + assertThat(taskRunner.runningTasks(), equalTo(0)); + }); + assertThat(taskRunner.queueSize(), 
equalTo(0)); + assertThat(repo.finishedFileSnapshotTasks(), equalTo(repo.expectedFileSnapshotTasks())); + assertThat(repo.finishedShardSnapshotTasks(), equalTo(enqueuedSnapshots)); + } + + public void testCompareToShardSnapshotTask() { + ShardSnapshotTaskRunner workers = new ShardSnapshotTaskRunner(1, executor, context -> {}, (context, fileInfo) -> {}); + SnapshotId s1 = new SnapshotId("s1", UUIDs.randomBase64UUID()); + SnapshotId s2 = new SnapshotId("s2", UUIDs.randomBase64UUID()); + SnapshotId s3 = new SnapshotId("s3", UUIDs.randomBase64UUID()); + ActionListener listener = ActionListener.noop(); + final long s1StartTime = threadPool.absoluteTimeInMillis(); + final long s2StartTime = s1StartTime + randomLongBetween(1, 1000); + SnapshotShardContext s1Context = dummyContext(s1, s1StartTime); + SnapshotShardContext s2Context = dummyContext(s2, s2StartTime); + SnapshotShardContext s3Context = dummyContext(s3, s2StartTime); + // Two tasks with the same start time and of the same type have the same priority + assertThat(workers.new ShardSnapshotTask(s2Context).compareTo(workers.new ShardSnapshotTask(s3Context)), equalTo(0)); + // Shard snapshot task always has a higher priority over file snapshot + assertThat( + workers.new ShardSnapshotTask(s1Context).compareTo(workers.new FileSnapshotTask(s1Context, dummyFileInfo(), listener)), + lessThan(0) + ); + assertThat( + workers.new ShardSnapshotTask(s2Context).compareTo(workers.new FileSnapshotTask(s1Context, dummyFileInfo(), listener)), + lessThan(0) + ); + // File snapshots are prioritized by start time. + assertThat( + workers.new FileSnapshotTask(s1Context, dummyFileInfo(), listener).compareTo( + workers.new FileSnapshotTask(s2Context, dummyFileInfo(), listener) + ), + lessThan(0) + ); + } +} diff --git a/server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java b/server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java index 4fac56a66c08b..c76b131a4cb1e 100644 --- a/server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java @@ -113,6 +113,7 @@ public void testSnapshotAndRestore() throws IOException { snapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), snapshot1Future ) ); @@ -155,6 +156,7 @@ public void testSnapshotAndRestore() throws IOException { snapshotStatus2, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), snapshot2future ) ); diff --git a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java index 46e436acc03b0..29accffd72730 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java @@ -1010,6 +1010,7 @@ protected ShardGeneration snapshotShard(final IndexShard shard, final Snapshot s snapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotRepository.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotRepository.java index 281a61090efbe..9c0e64102e09b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotRepository.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotRepository.java @@ -208,6 +208,7 @@ protected void closeInternal() { context.status(), context.getRepositoryMetaVersion(), context.userMetadata(), + context.snapshotStartTime(), context ) ); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotShardTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotShardTests.java index a4fa4d46c1bcc..d9c25639d811f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotShardTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/sourceonly/SourceOnlySnapshotShardTests.java @@ -134,6 +134,7 @@ public void testSourceIncomplete() throws IOException { indexShardSnapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ) @@ -176,6 +177,7 @@ public void testIncrementalSnapshot() throws IOException { indexShardSnapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ) @@ -207,6 +209,7 @@ public void testIncrementalSnapshot() throws IOException { indexShardSnapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ) @@ -238,6 +241,7 @@ public void testIncrementalSnapshot() throws IOException { indexShardSnapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ) @@ -299,6 +303,7 @@ public void testRestoreMinmal() throws IOException { indexShardSnapshotStatus, Version.CURRENT, Collections.emptyMap(), + randomMillisUpToYear9999(), future ) ); diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java index ee1056977e906..2a7c2bd3438af 100644 --- a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java @@ -79,6 +79,9 @@ public void setUpMocks() throws Exception { final var threadContext = new ThreadContext(Settings.EMPTY); final var threadPool = mock(ThreadPool.class); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.info(ThreadPool.Names.SNAPSHOT)).thenReturn( + new ThreadPool.Info(ThreadPool.Names.SNAPSHOT, ThreadPool.ThreadPoolType.FIXED, randomIntBetween(1, 10)) + ); when(clusterApplierService.threadPool()).thenReturn(threadPool); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryTests.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryTests.java index 20a386486cd21..2416076192ea9 100644 --- a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryTests.java +++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectoryTests.java @@ -632,6 +632,7 @@ protected void assertSnapshotOrGenericThread() { 
snapshotStatus, Version.CURRENT, emptyMap(), + randomMillisUpToYear9999(), future ) );
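
To illustrate how the new PrioritizedThrottledTaskRunner is meant to be driven, here is a minimal sketch that is not part of the patch. It assumes the class's type parameter is bounded as Comparable and Runnable (as the TestTask in PrioritizedThrottledTaskRunnerTests suggests); the task type, names, and priorities below are invented for illustration. With maxRunningTasks limited to 1, the first enqueued task may start immediately, and among the tasks still waiting in the priority queue the one with the lowest compareTo order runs first.

import org.elasticsearch.common.util.concurrent.PrioritizedThrottledTaskRunner;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class PrioritizedThrottledTaskRunnerExample {

    // Tasks must be both Runnable and Comparable; a lower priority value sorts first,
    // mirroring ShardSnapshotTask (priority 1) winning over FileSnapshotTask (priority 2).
    static class ExampleTask implements Comparable<ExampleTask>, Runnable {
        private final int priority;
        private final String name;

        ExampleTask(int priority, String name) {
            this.priority = priority;
            this.name = name;
        }

        @Override
        public int compareTo(ExampleTask other) {
            return Integer.compare(priority, other.priority);
        }

        @Override
        public void run() {
            System.out.println("running " + name);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        // Only one task is forked to the executor at a time; the rest wait in the priority queue.
        PrioritizedThrottledTaskRunner<ExampleTask> runner =
            new PrioritizedThrottledTaskRunner<>("example", 1, executor);
        runner.enqueueTask(new ExampleTask(2, "file upload a"));      // may start right away
        runner.enqueueTask(new ExampleTask(2, "file upload b"));      // queued
        runner.enqueueTask(new ExampleTask(1, "shard snapshot"));     // queued, but runs before "file upload b"

        executor.shutdown();
        executor.awaitTermination(10, TimeUnit.SECONDS);
    }
}

In the patch itself, ShardSnapshotTaskRunner layers exactly this mechanism on top of the SNAPSHOT pool: shard-level tasks carry priority 1 and per-file upload tasks priority 2, and ties between tasks of the same priority are broken by the snapshot start time, so older snapshots drain before newer ones.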