Skip to content

Further reduce allocations in TransportGetSnapshotsAction #110817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
Expand Down Expand Up @@ -248,18 +249,8 @@ void getMultipleReposSnapshotInfo(ActionListener<GetSnapshotsResponse> listener)
return;
}

SubscribableListener

.<RepositoryData>newForked(repositoryDataListener -> {
if (snapshotNamePredicate == SnapshotNamePredicate.MATCH_CURRENT_ONLY) {
repositoryDataListener.onResponse(null);
} else {
repositoriesService.repository(repoName).getRepositoryData(executor, repositoryDataListener);
}
})

SubscribableListener.<RepositoryData>newForked(l -> maybeGetRepositoryData(repoName, l))
.<Void>andThen((l, repositoryData) -> loadSnapshotInfos(repoName, repositoryData, l))

.addListener(listeners.acquire());
}
}
Expand All @@ -268,6 +259,14 @@ void getMultipleReposSnapshotInfo(ActionListener<GetSnapshotsResponse> listener)
.addListener(listener.map(ignored -> buildResponse()), executor, threadPool.getThreadContext());
}

/**
 * Supplies the {@link RepositoryData} for the named repository to {@code listener}, unless {@link #snapshotNamePredicate} is
 * {@link SnapshotNamePredicate#MATCH_CURRENT_ONLY}, in which case the listener is completed with {@code null} straight away and the
 * repository is not consulted.
 *
 * @param repositoryName the repository whose data is wanted.
 * @param listener       receives the repository data, or {@code null} in the {@code MATCH_CURRENT_ONLY} case.
 */
private void maybeGetRepositoryData(String repositoryName, ActionListener<RepositoryData> listener) {
    if (snapshotNamePredicate != SnapshotNamePredicate.MATCH_CURRENT_ONLY) {
        // the repository contents are needed, so fetch them using this action's executor
        repositoriesService.repository(repositoryName).getRepositoryData(executor, listener);
        return;
    }
    listener.onResponse(null);
}

private boolean skipRepository(String repositoryName) {
if (sortBy == SnapshotSortKey.REPOSITORY && fromSortValue != null) {
// If we are sorting by repository name with an offset given by fromSortValue, skip earlier repositories
Expand All @@ -277,61 +276,101 @@ private boolean skipRepository(String repositoryName) {
}
}

private void loadSnapshotInfos(String repo, @Nullable RepositoryData repositoryData, ActionListener<Void> listener) {
private void loadSnapshotInfos(String repositoryName, @Nullable RepositoryData repositoryData, ActionListener<Void> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.MANAGEMENT);

if (cancellableTask.notifyIfCancelled(listener)) {
return;
}

final Set<String> unmatchedRequiredNames = new HashSet<>(snapshotNamePredicate.requiredNames());
final Set<Snapshot> toResolve = new HashSet<>();

for (final var snapshotInProgress : snapshotsInProgress.forRepo(repo)) {
final var snapshotName = snapshotInProgress.snapshot().getSnapshotId().getName();
unmatchedRequiredNames.remove(snapshotName);
if (snapshotNamePredicate.test(snapshotName, true)) {
toResolve.add(snapshotInProgress.snapshot());
}
}

if (repositoryData != null) {
for (final var snapshotId : repositoryData.getSnapshotIds()) {
final var snapshotName = snapshotId.getName();
unmatchedRequiredNames.remove(snapshotName);
if (snapshotNamePredicate.test(snapshotName, false) && matchesPredicates(snapshotId, repositoryData)) {
toResolve.add(new Snapshot(repo, snapshotId));
}
}
}

if (unmatchedRequiredNames.isEmpty() == false) {
throw new SnapshotMissingException(repo, unmatchedRequiredNames.iterator().next());
}
cancellableTask.ensureNotCancelled();
ensureRequiredNamesPresent(repositoryName, repositoryData);

if (verbose) {
loadSnapshotInfos(repo, toResolve.stream().map(Snapshot::getSnapshotId).toList(), listener);
loadSnapshotInfos(repositoryName, getSnapshotIdIterator(repositoryName, repositoryData), listener);
} else {
assert fromSortValuePredicates.isMatchAll() : "filtering is not supported in non-verbose mode";
assert slmPolicyPredicate == SlmPolicyPredicate.MATCH_ALL_POLICIES : "filtering is not supported in non-verbose mode";

addSimpleSnapshotInfos(
toResolve,
repo,
getSnapshotIdIterator(repositoryName, repositoryData),
repositoryName,
repositoryData,
snapshotsInProgress.forRepo(repo).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList()
snapshotsInProgress.forRepo(repositoryName).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList()
);
listener.onResponse(null);
}
}

private void loadSnapshotInfos(String repositoryName, Collection<SnapshotId> snapshotIds, ActionListener<Void> listener) {
/**
 * Verify that each name reported as <i>required</i> by {@link #snapshotNamePredicate} is the name of some snapshot in the repository.
 * In-progress snapshots are consulted first; the snapshot IDs in {@code repositoryData} are only scanned if names are still
 * unaccounted for and {@code repositoryData} is available.
 *
 * @param repositoryName the repository being checked.
 * @param repositoryData the repository contents, or {@code null} if only in-progress snapshots should be considered.
 * @throws SnapshotMissingException naming one of the required names which was not found.
 */
private void ensureRequiredNamesPresent(String repositoryName, @Nullable RepositoryData repositoryData) {
    final var requiredNames = snapshotNamePredicate.requiredNames();
    if (requiredNames.isEmpty()) {
        // nothing is required, so avoid allocating the working set at all
        return;
    }

    // working copy from which every name we manage to find is struck off
    final var missingNames = new HashSet<>(requiredNames);
    for (final var entry : snapshotsInProgress.forRepo(repositoryName)) {
        missingNames.remove(entry.snapshot().getSnapshotId().getName());
    }

    if (missingNames.isEmpty() == false && repositoryData != null) {
        for (final var completedSnapshotId : repositoryData.getSnapshotIds()) {
            missingNames.remove(completedSnapshotId.getName());
        }
    }

    if (missingNames.isEmpty() == false) {
        throw new SnapshotMissingException(repositoryName, missingNames.iterator().next());
    }
}

/**
* @return an iterator over the snapshot IDs in the given repository which match {@link #snapshotNamePredicate}.
*/
private Iterator<SnapshotId> getSnapshotIdIterator(String repositoryName, @Nullable RepositoryData repositoryData) {

// now iterate through the snapshots again, returning matching IDs (or null)
final Set<SnapshotId> matchingInProgressSnapshots = new HashSet<>();
return Iterators.concat(
// matching in-progress snapshots first
Iterators.filter(
Iterators.map(
Comment on lines +338 to +339
Copy link
Contributor

@mhl-b mhl-b Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR: why do you prefer static iterator constructors over method chaining, like the stream API? Reading inside-out is not fun at all :) It took me a few round-trips to follow the code.

It does not look too complex to extend the current iterators. Something like this:

        snapshotsInProgress.forRepo(repositoryName)
            .iterator()
            .map(snapshotInProgress -> snapshotInProgress.snapshot().getSnapshotId())
            .filter(snapshotId -> {
                if (snapshotNamePredicate.test(snapshotId.getName(), true)) {
                    matchingInProgressSnapshots.add(snapshotId);
                    return true;
                } else {
                    return false;
                }
            })
            .concat(
                () -> repositoryData == null
                    ? Collections.emptyIterator()
                    : repositoryData.getSnapshotIds()
                        .iterator()
                        .filter(
                            snapshotId -> matchingInProgressSnapshots.contains(snapshotId) == false
                                && snapshotNamePredicate.test(snapshotId.getName(), false)
                                && matchesPredicates(snapshotId, repositoryData)
                        ));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stream APIs are very much more expressive than simple iterators, which is all very nice, but it turns out that this means they're outrageously expensive at runtime as a result.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't mean to use streams, but rather to allow an iterator to have non-static methods that create another iterator using map or filter. They might call the current static methods that we have. Then we could chain them, and I don't think it would cost much.

    class Iterator <T>{
        public <U> Iterator<U> map(Function<T, U> mapFn) {
            return Iterators.map(this, mapFn);
        }
        
        public Iterator<T> filter(Predicate<T> filterFn) {
            return Iterators.filter(this, filterFn);
        }
        
        public static void main(String[] args) {
            var iter = new Iterator<String>();
            iter.map(String::length)
                .filter(l -> l>0);
        }
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right yes I'd love to do that but this is java.util.Iterator<T>, it's in the JDK, so not something to which we can add methods ourselves. We could have our own Iterator whose interface we could extend, but then we'd end up having to add layers of wrapping to adapt it into the JDK one and back again and it'd end up being fairly messy in practice.

snapshotsInProgress.forRepo(repositoryName).iterator(),
snapshotInProgress -> snapshotInProgress.snapshot().getSnapshotId()
),
snapshotId -> {
if (snapshotNamePredicate.test(snapshotId.getName(), true)) {
matchingInProgressSnapshots.add(snapshotId);
return true;
} else {
return false;
}
}
),
repositoryData == null
// only returning in-progress snapshots
? Collections.emptyIterator()
// also return matching completed snapshots (except any ones that were also found to be in-progress)
: Iterators.filter(
repositoryData.getSnapshotIds().iterator(),
snapshotId -> matchingInProgressSnapshots.contains(snapshotId) == false
&& snapshotNamePredicate.test(snapshotId.getName(), false)
&& matchesPredicates(snapshotId, repositoryData)
)
);
}

private void loadSnapshotInfos(String repositoryName, Iterator<SnapshotId> snapshotIdIterator, ActionListener<Void> listener) {
if (cancellableTask.notifyIfCancelled(listener)) {
return;
}
final AtomicInteger repositoryTotalCount = new AtomicInteger();
final List<SnapshotInfo> snapshots = new ArrayList<>(snapshotIds.size());
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>(snapshotIds);
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>();
snapshotIdIterator.forEachRemaining(snapshotIdsToIterate::add);

final List<SnapshotInfo> snapshots = new ArrayList<>(snapshotIdsToIterate.size());
// first, look at the snapshots in progress
final List<SnapshotsInProgress.Entry> entries = SnapshotsService.currentSnapshots(
snapshotsInProgress,
Expand Down Expand Up @@ -409,7 +448,7 @@ public void onFailure(Exception e) {
}
})

// no need to synchronize access to snapshots: Repository#getSnapshotInfo fails fast but we're on the success path here
// no need to synchronize access to snapshots: all writes happen-before this read
.andThenAccept(ignored -> addResults(repositoryTotalCount.get(), snapshots))

.addListener(listener);
Expand All @@ -422,9 +461,9 @@ private void addResults(int repositoryTotalCount, List<SnapshotInfo> snapshots)
}

private void addSimpleSnapshotInfos(
final Set<Snapshot> toResolve,
final String repoName,
final RepositoryData repositoryData,
final Iterator<SnapshotId> snapshotIdIterator,
final String repositoryName,
@Nullable final RepositoryData repositoryData,
final List<SnapshotInfo> currentSnapshots
) {
if (repositoryData == null) {
Expand All @@ -433,11 +472,14 @@ private void addSimpleSnapshotInfos(
return;
} // else want non-current snapshots as well, which are found in the repository data

final Set<SnapshotId> toResolve = new HashSet<>();
snapshotIdIterator.forEachRemaining(toResolve::add);

List<SnapshotInfo> snapshotInfos = new ArrayList<>(currentSnapshots.size() + toResolve.size());
int repositoryTotalCount = 0;
for (SnapshotInfo snapshotInfo : currentSnapshots) {
assert snapshotInfo.startTime() == 0L && snapshotInfo.endTime() == 0L && snapshotInfo.totalShards() == 0L : snapshotInfo;
if (toResolve.remove(snapshotInfo.snapshot())) {
if (toResolve.remove(snapshotInfo.snapshot().getSnapshotId())) {
repositoryTotalCount += 1;
if (afterPredicate.test(snapshotInfo)) {
snapshotInfos.add(snapshotInfo);
Expand All @@ -448,19 +490,19 @@ private void addSimpleSnapshotInfos(
if (indices) {
for (IndexId indexId : repositoryData.getIndices().values()) {
for (SnapshotId snapshotId : repositoryData.getSnapshots(indexId)) {
if (toResolve.contains(new Snapshot(repoName, snapshotId))) {
if (toResolve.contains(snapshotId)) {
snapshotsToIndices.computeIfAbsent(snapshotId, (k) -> new ArrayList<>()).add(indexId.getName());
}
}
}
}
for (Snapshot snapshot : toResolve) {
for (SnapshotId snapshotId : toResolve) {
final var snapshotInfo = new SnapshotInfo(
snapshot,
snapshotsToIndices.getOrDefault(snapshot.getSnapshotId(), Collections.emptyList()),
new Snapshot(repositoryName, snapshotId),
snapshotsToIndices.getOrDefault(snapshotId, Collections.emptyList()),
Collections.emptyList(),
Collections.emptyList(),
repositoryData.getSnapshotState(snapshot.getSnapshotId())
repositoryData.getSnapshotState(snapshotId)
);
repositoryTotalCount += 1;
if (afterPredicate.test(snapshotInfo)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.function.ToIntFunction;

Expand Down Expand Up @@ -179,6 +180,59 @@ public void forEachRemaining(Consumer<? super U> action) {
}
}

/**
* @param input An iterator over <i>non-null</i> values.
* @param predicate The predicate with which to filter the input.
* @return an iterator which returns the values from {@code input} which match {@code predicate}.
*/
public static <T> Iterator<T> filter(Iterator<? extends T> input, Predicate<T> predicate) {
while (input.hasNext()) {
final var value = input.next();
assert value != null;
if (predicate.test(value)) {
return new FilterIterator<>(value, input, predicate);
}
}
Comment on lines +189 to +195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? It's the same as FilterIterator.next. Also, it might not be expected that constructing a FilterIterator will start pulling items; that should be done lazily by an explicit next call, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, no, and that's not how any of the other nearby iterator combinators work either.

Copy link
Contributor

@mhl-b mhl-b Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the nearby iterators do the same thing; I didn't notice.
But I don't understand why it's important to start iterating in the constructor until we reach the first element or exhaust the iterator. My first assumption would be that creating an iterator is cheap and doesn't do anything other than allocate a few pointers; I might even throw it away later without using it. At least that's what Rust does. But with this implementation it might do busy-work.

return Collections.emptyIterator();
}

/**
 * Iterator which yields, in order, the values of {@code input} that are accepted by {@code predicate}. The next value to return is
 * always held in {@link #next}, and {@code null} there marks exhaustion, which is why the input must not contain {@code null} values.
 */
private static final class FilterIterator<T> implements Iterator<T> {
    private final Iterator<? extends T> input;
    private final Predicate<? super T> predicate;
    // the next value to return (already known to match the predicate), or null once the input has been exhausted
    private T next;

    /**
     * @param value     the first matching value, already taken from {@code input}; must not be {@code null}.
     * @param input     the remaining input values, none of which have been examined yet.
     * @param predicate the predicate which later values must satisfy in order to be returned.
     */
    FilterIterator(T value, Iterator<? extends T> input, Predicate<? super T> predicate) {
        // Deliberately no `assert predicate.test(value)` here: the predicate is caller-supplied and may be stateful, so evaluating it
        // an extra time only when assertions are enabled would make the number of predicate invocations depend on the -ea flag.
        assert value != null;
        this.next = value;
        this.input = input;
        this.predicate = predicate;
    }

    @Override
    public boolean hasNext() {
        return next != null;
    }

    /**
     * Returns the buffered value after advancing {@code input} to the following match, if there is one.
     *
     * @throws NoSuchElementException if there are no more matching values.
     */
    @Override
    public T next() {
        if (hasNext() == false) {
            throw new NoSuchElementException();
        }
        final var value = next;
        // look ahead for the following match so that hasNext() stays a plain null check
        while (input.hasNext()) {
            final var laterValue = input.next();
            assert laterValue != null;
            if (predicate.test(laterValue)) {
                next = laterValue;
                return value;
            }
        }
        next = null;
        return value;
    }
}

public static <T, U> Iterator<U> flatMap(Iterator<? extends T> input, Function<T, Iterator<? extends U>> fn) {
while (input.hasNext()) {
final var value = fn.apply(input.next());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.common.collect;

import org.elasticsearch.common.Randomness;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;

Expand All @@ -23,6 +24,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -219,6 +221,27 @@ public void testMap() {
assertEquals(array.length, index.get());
}

/**
 * Exercises {@code Iterators.filter}: empty and fully-rejected inputs give the shared empty iterator, a threshold predicate yields
 * exactly the expected values in order, and (when assertions are enabled) a {@code null} input value trips an assertion before the
 * predicate is invoked.
 */
public void testFilter() {
    // empty input: the shared empty iterator comes back, and the predicate is expected never to run
    assertSame(Collections.emptyIterator(), Iterators.filter(Collections.emptyIterator(), i -> fail(null, "not called")));

    // a predicate which rejects everything also yields the shared empty iterator
    final var array = randomIntegerArray();
    assertSame(Collections.emptyIterator(), Iterators.filter(Iterators.forArray(array), i -> false));

    final var threshold = array.length > 0 && randomBoolean() ? randomFrom(array) : randomIntBetween(0, 1000);
    final Predicate<Integer> predicate = i -> i <= threshold;
    final var expectedResults = Arrays.stream(array).filter(predicate).toList();
    final var index = new AtomicInteger();
    Iterators.filter(Iterators.forArray(array), predicate)
        .forEachRemaining(i -> assertEquals(expectedResults.get(index.getAndIncrement()), i));
    // the per-element check above cannot notice an iterator that stops early, so verify that every expected value was produced
    assertEquals(expectedResults.size(), index.get());

    if (Assertions.ENABLED) {
        final var predicateCalled = new AtomicBoolean();
        final var inputIterator = Iterators.forArray(new Object[] { null });
        expectThrows(AssertionError.class, () -> Iterators.filter(inputIterator, i -> predicateCalled.compareAndSet(false, true)));
        assertFalse(predicateCalled.get());
    }
}

public void testFailFast() {
final var array = randomIntegerArray();
assertEmptyIterator(Iterators.failFast(Iterators.forArray(array), () -> true));
Expand Down