diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStream.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStream.java
new file mode 100644
index 0000000000000..bf3c99019443d
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStream.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+package org.elasticsearch.repositories.blobstore;
+
+import org.apache.lucene.store.AlreadyClosedException;
+import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.core.Releasables;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base class for doing chunked writes to a blob store. Some blob stores require either up-front knowledge of the size of the blob that
+ * will be written or writing it in chunks that are then joined into the final blob at the end of the write. This class provides a basis
+ * on which to implement an output stream that encapsulates such a chunked write.
+ *
+ * @param <T> type of chunk identifier
+ */
+public abstract class ChunkedBlobOutputStream<T> extends OutputStream {
+
+    /**
+     * List of identifiers of already written chunks.
+     */
+    protected final List<T> parts = new ArrayList<>();
+
+    /**
+     * Size of the write buffer above which it must be flushed to storage.
+     */
+    private final long maxBytesToBuffer;
+
+    /**
+     * Big arrays to be able to allocate buffers from pooled bytes.
+     */
+    private final BigArrays bigArrays;
+
+    /**
+     * Current write buffer.
+     */
+    protected ReleasableBytesStreamOutput buffer;
+
+    /**
+     * Set to true once no more calls to {@link #write} are expected and the blob has been received by {@link #write} in full so that
+     * {@link #close()} knows whether to clean up existing chunks or finish a chunked write.
+     */
+    protected boolean successful = false;
+
+    /**
+     * Is set to {@code true} once this stream has been closed.
+     */
+    private boolean closed = false;
+
+    /**
+     * Number of bytes flushed to blob storage so far.
+     */
+    protected long flushedBytes = 0L;
+
+    protected ChunkedBlobOutputStream(BigArrays bigArrays, long maxBytesToBuffer) {
+        this.bigArrays = bigArrays;
+        if (maxBytesToBuffer <= 0) {
+            throw new IllegalArgumentException("maximum buffer size must be positive");
+        }
+        this.maxBytesToBuffer = maxBytesToBuffer;
+        buffer = new ReleasableBytesStreamOutput(bigArrays);
+    }
+
+    @Override
+    public final void write(int b) throws IOException {
+        buffer.write(b);
+        maybeFlushBuffer();
+    }
+
+    @Override
+    public final void write(byte[] b, int off, int len) throws IOException {
+        buffer.write(b, off, len);
+        maybeFlushBuffer();
+    }
+
+    @Override
+    public final void close() throws IOException {
+        if (closed) {
+            assert false : "this output stream should only be closed once";
+            throw new AlreadyClosedException("already closed");
+        }
+        closed = true;
+        try {
+            if (successful) {
+                onCompletion();
+            } else {
+                onFailure();
+            }
+        } finally {
+            Releasables.close(buffer);
+        }
+    }
+
+    /**
+     * Mark all blob bytes as properly received by {@link #write}, indicating that {@link #close} may finalize the blob.
+     */
+    public final void markSuccess() {
+        this.successful = true;
+    }
+
+    /**
+     * Finish writing the current buffer contents to storage and track them by the given {@code partId}. Depending on whether all contents
+     * have already been written either prepare the write buffer for additional writes or release the buffer.
+     *
+     * @param partId part identifier to track for use when closing
+     */
+    protected final void finishPart(T partId) {
+        flushedBytes += buffer.size();
+        parts.add(partId);
+        buffer.close();
+        // only need a new buffer if we're not done yet
+        if (successful) {
+            buffer = null;
+        } else {
+            buffer = new ReleasableBytesStreamOutput(bigArrays);
+        }
+    }
+
+    /**
+     * Write the contents of {@link #buffer} to storage. Implementations should call {@link #finishPart} at the end to track the chunk
+     * of data just written and ready {@link #buffer} for the next write.
+     */
+    protected abstract void flushBuffer() throws IOException;
+
+    /**
+     * Invoked once all write chunks/parts are ready to be combined into the final blob. Implementations must invoke the necessary logic
+     * for combining the uploaded chunks into the final blob in this method.
+     */
+    protected abstract void onCompletion() throws IOException;
+
+    /**
+     * Invoked in case writing all chunks of data to storage failed. Implementations should run any cleanup required for the already
+     * written data in this method.
+     */
+    protected abstract void onFailure();
+
+    private void maybeFlushBuffer() throws IOException {
+        if (buffer.size() >= maxBytesToBuffer) {
+            flushBuffer();
+        }
+    }
+}
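For context, a concrete implementation is expected to upload the buffered bytes in flushBuffer, combine the uploaded parts in onCompletion, and clean up already uploaded parts in onFailure. Below is a minimal, illustrative sketch of that wiring against a hypothetical multipart-upload client; the HypotheticalMultipartClient interface and its method names are assumptions made for this example only and are not part of this change.

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.BigArrays;

import java.io.IOException;
import java.util.List;

// Sketch only: a concrete subclass backed by a hypothetical multipart upload client.
public class ExampleMultipartOutputStream extends ChunkedBlobOutputStream<String> {

    /** Hypothetical client API, assumed purely for illustration. */
    interface HypotheticalMultipartClient {
        String uploadPart(String blobName, BytesReference bytes) throws IOException;
        void completeMultipartUpload(String blobName, List<String> partIds) throws IOException;
        void abortMultipartUpload(String blobName, List<String> partIds);
    }

    private final HypotheticalMultipartClient client;
    private final String blobName;

    ExampleMultipartOutputStream(BigArrays bigArrays, long chunkSize,
                                 HypotheticalMultipartClient client, String blobName) {
        super(bigArrays, chunkSize);
        this.client = client;
        this.blobName = blobName;
    }

    @Override
    protected void flushBuffer() throws IOException {
        // upload the currently buffered bytes as one part and record its identifier
        final String partId = client.uploadPart(blobName, buffer.bytes());
        finishPart(partId);
    }

    @Override
    protected void onCompletion() throws IOException {
        // flush any bytes still buffered, then combine all uploaded parts into the final blob
        if (buffer.size() > 0) {
            flushBuffer();
        }
        client.completeMultipartUpload(blobName, parts);
    }

    @Override
    protected void onFailure() {
        // the write failed or was never marked successful: clean up parts uploaded so far
        client.abortMultipartUpload(blobName, parts);
    }
}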
diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStreamTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStreamTests.java
new file mode 100644
index 0000000000000..8dc2b8d21b2d9
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/ChunkedBlobOutputStreamTests.java
@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.repositories.blobstore;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.zip.CRC32;
+import java.util.zip.CheckedOutputStream;
+
+public class ChunkedBlobOutputStreamTests extends ESTestCase {
+
+    private BigArrays bigArrays;
+
+    @Override
+    public void setUp() throws Exception {
+        bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
+        super.setUp();
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+    }
+
+    public void testSuccessfulChunkedWrite() throws IOException {
+        final long chunkSize = randomLongBetween(10, 1024);
+        final CRC32 checksumIn = new CRC32();
+        final CRC32 checksumOut = new CRC32();
+        final CheckedOutputStream out = new CheckedOutputStream(OutputStream.nullOutputStream(), checksumOut);
+        final AtomicLong writtenBytesCounter = new AtomicLong(0L);
+        final long bytesToWrite = randomLongBetween(chunkSize - 5, 1000 * chunkSize);
+        long written = 0;
+        try (ChunkedBlobOutputStream<Integer> stream = new ChunkedBlobOutputStream<>(bigArrays, chunkSize) {
+
+            private final AtomicInteger partIdSupplier = new AtomicInteger();
+
+            @Override
+            protected void flushBuffer() throws IOException {
+                final BytesReference bytes = buffer.bytes();
+                bytes.writeTo(out);
+                writtenBytesCounter.addAndGet(bytes.length());
+                finishPart(partIdSupplier.incrementAndGet());
+            }
+
+            @Override
+            protected void onCompletion() throws IOException {
+                if (buffer.size() > 0) {
+                    flushBuffer();
+                }
+                out.flush();
+                for (int i = 0; i < partIdSupplier.get(); i++) {
+                    assertEquals((long) i + 1, (long) parts.get(i));
+                }
+            }
+
+            @Override
+            protected void onFailure() {
+                fail("not supposed to fail");
+            }
+        }) {
+            final byte[] buffer = new byte[randomInt(Math.toIntExact(2 * chunkSize)) + 1];
+            while (written < bytesToWrite) {
+                if (randomBoolean()) {
+                    random().nextBytes(buffer);
+                    final int offset = randomInt(buffer.length - 2) + 1;
+                    final int length = Math.toIntExact(Math.min(bytesToWrite - written, buffer.length - offset));
+                    stream.write(buffer, offset, length);
+                    checksumIn.update(buffer, offset, length);
+                    written += length;
+                } else {
+                    int oneByte = randomByte();
+                    stream.write(oneByte);
+                    checksumIn.update(oneByte);
+                    written++;
+                }
+            }
+            stream.markSuccess();
+        }
+        assertEquals(bytesToWrite, written);
+        assertEquals(bytesToWrite, writtenBytesCounter.get());
+        assertEquals(checksumIn.getValue(), checksumOut.getValue());
+    }
+
+    public void testExceptionDuringChunkedWrite() throws IOException {
+        final long chunkSize = randomLongBetween(10, 1024);
+        final AtomicLong writtenBytesCounter = new AtomicLong(0L);
+        final long bytesToWrite = randomLongBetween(chunkSize - 5, 1000 * chunkSize);
+        long written = 0;
+        final AtomicBoolean onFailureCalled = new AtomicBoolean(false);
+        try (ChunkedBlobOutputStream<Integer> stream = new ChunkedBlobOutputStream<>(bigArrays, chunkSize) {
+
+            private final AtomicInteger partIdSupplier = new AtomicInteger();
+
+            @Override
+            protected void flushBuffer() {
+                writtenBytesCounter.addAndGet(buffer.size());
+                finishPart(partIdSupplier.incrementAndGet());
+            }
+
+            @Override
+            protected void onCompletion() {
+                fail("supposed to fail");
+            }
+
+            @Override
+            protected void onFailure() {
+                for (int i = 0; i < partIdSupplier.get(); i++) {
+                    assertEquals((long) i + 1, (long) parts.get(i));
+                }
+                assertTrue(onFailureCalled.compareAndSet(false, true));
+            }
+        }) {
+            final byte[] buffer = new byte[randomInt(Math.toIntExact(2 * chunkSize)) + 1];
+            while (written < bytesToWrite) {
+                if (rarely()) {
+                    break;
+                } else if (randomBoolean()) {
+                    random().nextBytes(buffer);
+                    final int offset = randomInt(buffer.length - 2) + 1;
+                    final int length = Math.toIntExact(Math.min(bytesToWrite - written, buffer.length - offset));
+                    stream.write(buffer, offset, length);
+                    written += length;
+                } else {
+                    int oneByte = randomByte();
+                    stream.write(oneByte);
+                    written++;
+                }
+            }
+        }
+        assertTrue(onFailureCalled.get());
+    }
+}
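On the caller side, the contract these tests exercise is: write all of the blob's bytes to the stream, call markSuccess() once the full blob has been handed over, then close(). A rough sketch of that pattern follows; createStream() is a placeholder for constructing a concrete ChunkedBlobOutputStream subclass and is not a method introduced by this change.

// Sketch of the calling pattern; createStream() is hypothetical.
try (ChunkedBlobOutputStream<Integer> stream = createStream()) {
    bytesReference.writeTo(stream); // may trigger flushBuffer() several times along the way
    stream.markSuccess();           // without this call, close() invokes onFailure() instead
}
// close() has now invoked either onCompletion() (finalize the blob) or onFailure() (clean up parts)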