Skip to content

Reduce memory usage for chunk-encoded streaming uploads, like those used by flexible checksums in S3. #4858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-a5aec87.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Reduce how many times input data is copied when writing to chunked encoded operations, like S3's PutObject."
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aRequestSigni
.builder()
.inputStream(inputStream)
.chunkSize(chunkSize)
.header(chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8));
.header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8));

preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

package software.amazon.awssdk.http.auth.aws.crt.internal.signer;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -45,14 +46,14 @@ public RollingSigner(byte[] seedSignature, AwsSigningConfig signingConfig) {
this.signingConfig = signingConfig;
}

private static byte[] signChunk(byte[] chunkBody, byte[] previousSignature, AwsSigningConfig signingConfig) {
private static byte[] signChunk(ByteBuffer chunkBody, byte[] previousSignature, AwsSigningConfig signingConfig) {
// All the config remains the same as signing config except the Signature Type.
AwsSigningConfig configCopy = signingConfig.clone();
configCopy.setSignatureType(AwsSigningConfig.AwsSignatureType.HTTP_REQUEST_CHUNK);
configCopy.setSignedBodyHeader(AwsSigningConfig.AwsSignedBodyHeaderType.NONE);
configCopy.setSignedBodyValue(null);

HttpRequestBodyStream crtBody = new CrtInputStream(() -> new ByteArrayInputStream(chunkBody));
HttpRequestBodyStream crtBody = new CrtInputStream(() -> new ByteBufferBackedInputStream(chunkBody));
return CompletableFutureUtils.joinLikeSync(AwsSigner.signChunk(crtBody, previousSignature, configCopy));
}

Expand All @@ -75,7 +76,7 @@ private static AwsSigningResult signTrailerHeaders(Map<String, List<String>> hea
/**
* Using a template that incorporates the previous calculated signature, sign the string and return it.
*/
public byte[] sign(byte[] chunkBody) {
public byte[] sign(ByteBuffer chunkBody) {
previousSignature = signChunk(chunkBody, previousSignature, signingConfig);
return previousSignature;
}
Expand All @@ -89,4 +90,29 @@ public byte[] sign(Map<String, List<String>> headerMap) {
    public void reset() {
        // Restart the rolling-signature chain from the seed (request-level) signature,
        // e.g. when the content stream provider is asked for a fresh stream.
        previousSignature = seedSignature;
    }

private static class ByteBufferBackedInputStream extends InputStream {
private final ByteBuffer buf;

private ByteBufferBackedInputStream(ByteBuffer buf) {
this.buf = buf;
}

public int read() {
if (!buf.hasRemaining()) {
return -1;
}
return buf.get() & 0xFF;
}

public int read(byte[] bytes, int off, int len) {
if (!buf.hasRemaining()) {
return -1;
}

len = Math.min(len, buf.remaining());
buf.get(bytes, off, len);
return len;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.crt.internal.signer;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
Expand All @@ -38,11 +39,8 @@ public void reset() {
}

@Override
public Pair<byte[], byte[]> get(byte[] chunk) {
public Pair<byte[], byte[]> get(ByteBuffer chunk) {
byte[] chunkSig = signer.sign(chunk);
return Pair.of(
"chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig
);
return Pair.of("chunk-signature".getBytes(StandardCharsets.UTF_8), chunkSig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4RequestSignin
.builder()
.inputStream(payload.newStream())
.chunkSize(chunkSize)
.header(chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8));
.header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8));

preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.nio.ByteBuffer;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Pair;

Expand All @@ -32,5 +33,5 @@
@FunctionalInterface
@SdkInternalApi
public interface ChunkExtensionProvider extends Resettable {
    /**
     * Returns a chunk-extension as a (name, value) pair of byte arrays for the given chunk data,
     * to be rendered as {@code ;name=value} after the chunk-size header.
     *
     * @param chunk the chunk data; callers may pass a read-only view, so implementations
     *              should not rely on mutating the buffer
     * @return the extension name and value
     */
    Pair<byte[], byte[]> get(ByteBuffer chunk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.nio.ByteBuffer;
import software.amazon.awssdk.annotations.SdkInternalApi;

/**
Expand All @@ -27,5 +28,5 @@
@FunctionalInterface
@SdkInternalApi
public interface ChunkHeaderProvider extends Resettable {
    /**
     * Returns the header bytes for the given chunk data (by default, the hex-encoded number of
     * remaining bytes in the buffer).
     *
     * @param chunk the chunk data; callers may pass a read-only view, so implementations
     *              should not rely on mutating the buffer
     * @return the chunk header bytes
     */
    byte[] get(ByteBuffer chunk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Logger;
Expand Down Expand Up @@ -52,6 +53,10 @@ public final class ChunkedEncodedInputStream extends InputStream {
private static final Logger LOG = Logger.loggerFor(ChunkedEncodedInputStream.class);
private static final byte[] CRLF = {'\r', '\n'};
private static final byte[] END = {};
private static final byte[] SEMICOLON = {';'};
private static final byte[] EQUALS = {'='};
private static final byte[] COLON = {':'};
private static final byte[] COMMA = {','};

private final InputStream inputStream;
private final int chunkSize;
Expand Down Expand Up @@ -101,14 +106,14 @@ private Chunk getChunk(InputStream stream) throws IOException {
if (currentChunk != null) {
currentChunk.close();
}
// we *have* to read from the backing stream in order to figure out if it's the end or not
// TODO(sra-identity-and-auth): We can likely optimize this by not copying the entire chunk of data into memory

// We have to read from the input stream into a format that can be used for signing and headers.
byte[] chunkData = new byte[chunkSize];
int read = read(stream, chunkData, chunkSize);

if (read > 0) {
// set the current chunk to the newly written chunk
return getNextChunk(Arrays.copyOf(chunkData, read));
return getNextChunk(ByteBuffer.wrap(chunkData, 0, read));
}

LOG.debug(() -> "End of backing stream reached. Reading final chunk.");
Expand Down Expand Up @@ -142,58 +147,71 @@ private int read(InputStream inputStream, byte[] buf, int maxBytesToRead) throws
* Create a chunk from a byte-array, which includes the header, the extensions, and the chunk data. The input array should be
* correctly sized, i.e. the number of bytes should equal its length.
*/
private Chunk getNextChunk(byte[] data) throws IOException {
ByteArrayOutputStream chunkStream = new ByteArrayOutputStream();
writeChunk(data, chunkStream);
chunkStream.write(CRLF);
byte[] newChunkData = chunkStream.toByteArray();

return Chunk.create(new ByteArrayInputStream(newChunkData), newChunkData.length);
private Chunk getNextChunk(ByteBuffer data) {
LengthAwareSequenceInputStream newChunkData =
LengthAwareSequenceInputStream.builder()
.add(createChunkStream(data))
.add(CRLF)
.build();
return Chunk.create(newChunkData, newChunkData.size);
}

/**
* Create the final chunk, which includes the header, the extensions, the chunk (if applicable), and the trailer
*/
private Chunk getFinalChunk() throws IOException {
ByteArrayOutputStream chunkStream = new ByteArrayOutputStream();
writeChunk(END, chunkStream);
writeTrailers(chunkStream);
chunkStream.write(CRLF);
byte[] newChunkData = chunkStream.toByteArray();

return Chunk.create(new ByteArrayInputStream(newChunkData), newChunkData.length);
LengthAwareSequenceInputStream chunkData =
LengthAwareSequenceInputStream.builder()
.add(createChunkStream(ByteBuffer.wrap(END)))
.add(createTrailerStream())
.add(CRLF)
.build();

return Chunk.create(chunkData, chunkData.size);
}

private void writeChunk(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
writeHeader(chunk, outputStream);
writeExtensions(chunk, outputStream);
outputStream.write(CRLF);
outputStream.write(chunk);
private LengthAwareSequenceInputStream createChunkStream(ByteBuffer chunkData) {
return LengthAwareSequenceInputStream.builder()
.add(createHeaderStream(chunkData.asReadOnlyBuffer()))
.add(createExtensionsStream(chunkData.asReadOnlyBuffer()))
.add(CRLF)
.add(new ByteArrayInputStream(chunkData.array(),
chunkData.arrayOffset(),
chunkData.remaining()))
.build();
}

private void writeHeader(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
byte[] hdr = header.get(chunk);
outputStream.write(hdr);
    // Renders the chunk header (by default the hex-encoded chunk size) as an in-memory stream.
    private ByteArrayInputStream createHeaderStream(ByteBuffer chunkData) {
        return new ByteArrayInputStream(header.get(chunkData));
    }

private void writeExtensions(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
private LengthAwareSequenceInputStream createExtensionsStream(ByteBuffer chunkData) {
LengthAwareSequenceInputStream.Builder result = LengthAwareSequenceInputStream.builder();
for (ChunkExtensionProvider chunkExtensionProvider : extensions) {
Pair<byte[], byte[]> ext = chunkExtensionProvider.get(chunk);
outputStream.write((byte) ';');
outputStream.write(ext.left());
outputStream.write((byte) '=');
outputStream.write(ext.right());
Pair<byte[], byte[]> ext = chunkExtensionProvider.get(chunkData);
result.add(SEMICOLON);
result.add(ext.left());
result.add(EQUALS);
result.add(ext.right());
}
return result.build();
}

private void writeTrailers(ByteArrayOutputStream outputStream) throws IOException {
private LengthAwareSequenceInputStream createTrailerStream() throws IOException {
LengthAwareSequenceInputStream.Builder result = LengthAwareSequenceInputStream.builder();
for (TrailerProvider trailer : trailers) {
Pair<String, List<String>> tlr = trailer.get();
outputStream.write(tlr.left().getBytes(StandardCharsets.UTF_8));
outputStream.write((byte) ':');
outputStream.write(String.join(",", tlr.right()).getBytes(StandardCharsets.UTF_8));
outputStream.write(CRLF);
result.add(tlr.left().getBytes(StandardCharsets.UTF_8));
result.add(COLON);
for (String trailerValue : tlr.right()) {
result.add(trailerValue.getBytes(StandardCharsets.UTF_8));
result.add(COMMA);
}

// Replace trailing comma with clrf
result.replaceLast(new ByteArrayInputStream(CRLF), COMMA.length);
}
return result.build();
}

@Override
Expand All @@ -216,7 +234,8 @@ public static class Builder {
private final List<TrailerProvider> trailers = new ArrayList<>();
private InputStream inputStream;
private int chunkSize;
private ChunkHeaderProvider header = chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8);
private ChunkHeaderProvider header =
chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8);

public InputStream inputStream() {
return this.inputStream;
Expand Down Expand Up @@ -267,5 +286,51 @@ public ChunkedEncodedInputStream build() {
return new ChunkedEncodedInputStream(this);
}
}


    /**
     * A {@link SequenceInputStream} that also tracks the total number of bytes the concatenated
     * streams will produce, so a chunk's size can be reported without buffering its content.
     */
    private static class LengthAwareSequenceInputStream extends SequenceInputStream {
        // Total number of bytes this stream will produce, computed at build time.
        private final int size;

        private LengthAwareSequenceInputStream(Builder builder) {
            super(Collections.enumeration(builder.streams));
            this.size = builder.size;
        }

        private static Builder builder() {
            return new Builder();
        }

        private static class Builder {
            private final List<InputStream> streams = new ArrayList<>();
            private int size = 0;

            public Builder add(ByteArrayInputStream stream) {
                streams.add(stream);
                // available() on a ByteArrayInputStream is exact: the remaining in-memory bytes.
                size += stream.available();
                return this;
            }

            public Builder add(byte[] stream) {
                return add(new ByteArrayInputStream(stream));
            }

            public Builder add(LengthAwareSequenceInputStream stream) {
                streams.add(stream);
                size += stream.size;
                return this;
            }

            // Replaces the most recently added stream with the given one, adjusting the tracked
            // size: lastLength (the replaced stream's byte count) is subtracted and the new
            // stream's length added.
            // NOTE(review): assumes at least one stream was already added and that the replaced
            // stream produces exactly lastLength bytes — confirm at call sites.
            public Builder replaceLast(ByteArrayInputStream stream, int lastLength) {
                streams.set(streams.size() - 1, stream);
                size -= lastLength;
                size += stream.available();
                return this;
            }

            public LengthAwareSequenceInputStream build() {
                return new LengthAwareSequenceInputStream(this);
            }
        }
    }
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.hash;
import static software.amazon.awssdk.utils.BinaryUtils.toHex;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
Expand All @@ -42,7 +43,7 @@ public void reset() {
signer.reset();
}

private String getStringToSign(String previousSignature, byte[] chunk) {
private String getStringToSign(String previousSignature, ByteBuffer chunk) {
// build the string-to-sign template for the rolling-signer to sign
return String.join("\n",
"AWS4-HMAC-SHA256-PAYLOAD",
Expand All @@ -55,11 +56,9 @@ private String getStringToSign(String previousSignature, byte[] chunk) {
}

@Override
public Pair<byte[], byte[]> get(byte[] chunk) {
public Pair<byte[], byte[]> get(ByteBuffer chunk) {
String chunkSig = signer.sign(previousSig -> getStringToSign(previousSig, chunk));
return Pair.of(
"chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig.getBytes(StandardCharsets.UTF_8)
);
return Pair.of("chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig.getBytes(StandardCharsets.UTF_8));
}
}
Loading