Skip to content

memory improvement for ReadMessageAsync compression #2629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 112 additions & 37 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

using System.Buffers;
using System.Buffers.Binary;
using System.Collections;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
var compressed = ReadCompressedFlag(buffer[0]);
var length = ReadMessageLength(buffer.AsSpan(1, 4));

if (length > 0)
if (length > call.Channel.ReceiveMaxMessageSize)
{
if (length > call.Channel.ReceiveMaxMessageSize)
{
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
}

// Replace buffer if the message doesn't fit
if (buffer.Length < length)
{
ArrayPool<byte>.Shared.Return(buffer);
buffer = ArrayPool<byte>.Shared.Rent(length);
}

await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
}

cancellationToken.ThrowIfCancellationRequested();

ReadOnlySequence<byte> payload;
if (compressed)
TResponse message;
if (compressed && length > 0)
{
if (grpcEncoding == null)
{
Expand All @@ -130,28 +119,55 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
throw call.CreateRpcException(IdentityMessageEncodingMessageStatus);
}

// Performance improvement would be to decompress without converting to an intermediary byte array
if (!TryDecompressMessage(call.Logger, grpcEncoding, call.Channel.CompressionProviders, buffer, length, out var decompressedMessage))
if (call.Channel.CompressionProviders.TryGetValue(grpcEncoding, out var compressionProvider))
{
GrpcCallLog.DecompressingMessage(call.Logger, compressionProvider.EncodingName);
var moreBuffers = new List<byte[]>();
try
{
int lastLength;
using (var compressionStream = compressionProvider.CreateDecompressionStream(new FixedLengthStream(responseStream, length)))
{
var underLohLength = Math.Min(Math.Max(4096, length), 65536);
lastLength = await ReadStreamToBuffers(compressionStream, buffer, moreBuffers, underLohLength, cancellationToken).ConfigureAwait(false);
}
call.DeserializationContext.SetPayload(BuffersToReadOnlySequence(buffer, moreBuffers, lastLength));
message = deserializer(call.DeserializationContext);
}
finally
{
foreach (var byteArray in moreBuffers)
{
ArrayPool<byte>.Shared.Return(byteArray);
}
}
}
else
{
var supportedEncodings = new List<string>();
supportedEncodings.Add(GrpcProtocolConstants.IdentityGrpcEncoding);
var supportedEncodings = new List<string>(call.Channel.CompressionProviders.Count + 1) { GrpcProtocolConstants.IdentityGrpcEncoding };
supportedEncodings.AddRange(call.Channel.CompressionProviders.Select(c => c.Key));
throw call.CreateRpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, supportedEncodings));
}

payload = decompressedMessage;
}
else
{
payload = new ReadOnlySequence<byte>(buffer, 0, length);
if (length > 0)
{
// Replace buffer if the message doesn't fit
if (buffer.Length < length)
{
ArrayPool<byte>.Shared.Return(buffer);
buffer = ArrayPool<byte>.Shared.Rent(length);
}
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
}
call.DeserializationContext.SetPayload(new ReadOnlySequence<byte>(buffer, 0, length));
message = deserializer(call.DeserializationContext);
}
call.DeserializationContext.SetPayload(null);

GrpcCallLog.DeserializingMessage(call.Logger, length, typeof(TResponse));

call.DeserializationContext.SetPayload(payload);
var message = deserializer(call.DeserializationContext);
call.DeserializationContext.SetPayload(null);

if (singleMessage)
{
// Check that there is no additional content in the stream for a single message
Expand Down Expand Up @@ -251,24 +267,83 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory<
}
}

private static bool TryDecompressMessage(ILogger logger, string compressionEncoding, Dictionary<string, ICompressionProvider> compressionProviders, byte[] messageData, int length, out ReadOnlySequence<byte> result)
private sealed class FixedLengthStream(Stream stream, int length) : Stream
{
if (compressionProviders.TryGetValue(compressionEncoding, out var compressionProvider))
private int _bytesRead;
public override int Read(byte[] buffer, int offset, int count)
{
GrpcCallLog.DecompressingMessage(logger, compressionProvider.EncodingName);
var bytesToRead = Math.Min(count, length - _bytesRead);
if (bytesToRead <= 0)
{
return 0;
}
var bytesRead = stream.Read(buffer, offset, bytesToRead);
if (bytesRead == 0)
{
throw new InvalidDataException("Unexpected end of content while reading the message content.");
}
_bytesRead += bytesRead;
return bytesRead;
}
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => length;
public override long Position { get => _bytesRead; set => throw new NotSupportedException(); }
public override void Flush() => throw new NotSupportedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
}

var output = new MemoryStream();
using (var compressionStream = compressionProvider.CreateDecompressionStream(new MemoryStream(messageData, 0, length, writable: true, publiclyVisible: true)))
private static async Task<int> ReadStreamToBuffers(Stream stream, byte[] buffer, List<byte[]> moreBuffers, int moreLength, CancellationToken cancellationToken)
{
while (true)
{
var offset = 0;
while (offset < buffer.Length)
{
compressionStream.CopyTo(output);
var read = await stream.ReadAsync(buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
if (read == 0)
{
return offset;
}
offset += read;
}
moreBuffers.Add(buffer = ArrayPool<byte>.Shared.Rent(moreLength));
}
}

result = new ReadOnlySequence<byte>(output.GetBuffer(), 0, (int)output.Length);
return true;
private static ReadOnlySequence<byte> BuffersToReadOnlySequence(byte[] buffer, List<byte[]> moreBuffers, int lastLength)
{
if (moreBuffers.Count == 0)
{
return new ReadOnlySequence<byte>(buffer, 0, lastLength);
}
var runningIndex = buffer.Length;
for (var i = moreBuffers.Count - 2; i >= 0; i--)
{
runningIndex += moreBuffers[i].Length;
}
var endSegment = new ReadOnlySequenceSegmentByte(moreBuffers[moreBuffers.Count - 1].AsMemory(0, lastLength), null, runningIndex);
var startSegment = endSegment;
for (var i = moreBuffers.Count - 2; i >= 0; i--)
{
var bytes = moreBuffers[i];
startSegment = new ReadOnlySequenceSegmentByte(bytes, startSegment, runningIndex -= bytes.Length);
}
startSegment = new ReadOnlySequenceSegmentByte(buffer, startSegment, 0);
return new ReadOnlySequence<byte>(startSegment, 0, endSegment, lastLength);
}

result = default;
return false;
private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment<byte>
{
public ReadOnlySequenceSegmentByte(ReadOnlyMemory<byte> memory, ReadOnlySequenceSegmentByte? next, int runningIndex)
{
Memory = memory;
Next = next;
RunningIndex = runningIndex;
}
}

private static bool ReadCompressedFlag(byte flag)
Expand Down