Skip to content

feat: gRPC stream connection deadline #999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ public final class Config {
static final String DEFAULT_HOST = "localhost";

static final int DEFAULT_DEADLINE = 500;
static final int DEFAULT_STREAM_DEADLINE_MS = 10 * 60 * 1000;
static final int DEFAULT_MAX_CACHE_SIZE = 1000;
static final long DEFAULT_KEEP_ALIVE = 0;

@@ -31,6 +32,7 @@ public final class Config {
static final String MAX_EVENT_STREAM_RETRIES_ENV_VAR_NAME = "FLAGD_MAX_EVENT_STREAM_RETRIES";
static final String BASE_EVENT_STREAM_RETRY_BACKOFF_MS_ENV_VAR_NAME = "FLAGD_RETRY_BACKOFF_MS";
static final String DEADLINE_MS_ENV_VAR_NAME = "FLAGD_DEADLINE_MS";
static final String STREAM_DEADLINE_MS_ENV_VAR_NAME = "FLAGD_STREAM_DEADLINE_MS";
static final String SOURCE_SELECTOR_ENV_VAR_NAME = "FLAGD_SOURCE_SELECTOR";
static final String OFFLINE_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH";
static final String KEEP_ALIVE_MS_ENV_VAR_NAME_OLD = "FLAGD_KEEP_ALIVE_TIME";
Original file line number Diff line number Diff line change
@@ -92,6 +92,14 @@ public class FlagdOptions {
@Builder.Default
private int deadline = fallBackToEnvOrDefault(Config.DEADLINE_MS_ENV_VAR_NAME, Config.DEFAULT_DEADLINE);

/**
* Streaming connection deadline in milliseconds.
* Set to 0 to disable the deadline.
*/
@Builder.Default
private int streamDeadlineMs = fallBackToEnvOrDefault(Config.STREAM_DEADLINE_MS_ENV_VAR_NAME,
Config.DEFAULT_STREAM_DEADLINE_MS);

/**
* Selector to be used with flag sync gRPC contract.
**/
@@ -101,7 +109,7 @@ public class FlagdOptions {
/**
* gRPC client KeepAlive in milliseconds. Disabled with 0.
* Defaults to 0 (disabled).
*
*
**/
@Builder.Default
private long keepAlive = fallBackToEnvOrDefault(Config.KEEP_ALIVE_MS_ENV_VAR_NAME,
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@
import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache;
import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import lombok.extern.slf4j.Slf4j;

@@ -52,12 +54,18 @@ public void onNext(EventStreamResponse value) {
}

@Override
public void onError(Throwable t) {
log.warn("event stream", t);
if (this.cache.getEnabled()) {
this.cache.clear();
public void onError(Throwable throwable) {
if (throwable instanceof StatusRuntimeException
&& ((StatusRuntimeException) throwable).getStatus().getCode()
.equals(Code.DEADLINE_EXCEEDED)) {
log.debug(String.format("stream deadline reached; will re-establish"));
} else {
log.error(String.format("event stream error", throwable));
if (this.cache.getEnabled()) {
this.cache.clear();
}
this.onConnectionEvent.accept(false, Collections.emptyList());
}
this.onConnectionEvent.accept(false, Collections.emptyList());

// handle last call of this stream
handleEndOfStream();
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ public class GrpcConnector {

private final int startEventStreamRetryBackoff;
private final long deadline;
private final long streamDeadlineMs;

private final Cache cache;
private final Consumer<ConnectionEvent> onConnectionEvent;
@@ -64,6 +65,7 @@ public GrpcConnector(final FlagdOptions options, final Cache cache, final Suppli
this.startEventStreamRetryBackoff = options.getRetryBackoffMs();
this.eventStreamRetryBackoff = options.getRetryBackoffMs();
this.deadline = options.getDeadline();
this.streamDeadlineMs = options.getStreamDeadlineMs();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[suggestion] using the options object instead of separate fields would reduce this over head of adding a new field all the time

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can go either way on this one. In fact it might be a good thing to do in a pure refactor/cleanup PR. There's also some naming that we can probably improve with the provider.

this.cache = cache;
this.onConnectionEvent = onConnectionEvent;
this.connectedSupplier = connectedSupplier;
@@ -126,7 +128,14 @@ private void observeEventStream() {
while (this.eventStreamAttempt <= this.maxEventStreamRetries) {
final StreamObserver<EventStreamResponse> responseObserver = new EventStreamObserver(sync, this.cache,
this::onConnectionEvent);
this.serviceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver);

ServiceGrpc.ServiceStub localServiceStub = this.serviceStub;

if (this.streamDeadlineMs > 0) {
localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS);
}

localServiceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver);

try {
synchronized (sync) {
Original file line number Diff line number Diff line change
@@ -22,6 +22,8 @@
import io.grpc.Context;
import io.grpc.Context.CancellableContext;
import io.grpc.ManagedChannel;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import lombok.extern.slf4j.Slf4j;

/**
@@ -43,6 +45,7 @@ public class GrpcStreamConnector implements Connector {
private final FlagSyncServiceStub serviceStub;
private final FlagSyncServiceBlockingStub serviceBlockingStub;
private final int deadline;
private final int streamDeadlineMs;
private final String selector;

/**
@@ -55,6 +58,7 @@ public GrpcStreamConnector(final FlagdOptions options) {
serviceStub = FlagSyncServiceGrpc.newStub(channel);
serviceBlockingStub = FlagSyncServiceGrpc.newBlockingStub(channel);
deadline = options.getDeadline();
streamDeadlineMs = options.getStreamDeadlineMs();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[suggestion] Cant we just use the options within this class, and pass it further down the call chain? - adding a new field and parameters might be tedious over time

selector = options.getSelector();
}

@@ -64,7 +68,8 @@ public GrpcStreamConnector(final FlagdOptions options) {
public void init() {
Thread listener = new Thread(() -> {
try {
observeEventStream(blockingQueue, shutdown, serviceStub, serviceBlockingStub, selector, deadline);
observeEventStream(blockingQueue, shutdown, serviceStub, serviceBlockingStub, selector, deadline,
streamDeadlineMs);
} catch (InterruptedException e) {
log.warn("gRPC event stream interrupted, flag configurations are stale", e);
Thread.currentThread().interrupt();
@@ -114,7 +119,8 @@ static void observeEventStream(final BlockingQueue<QueuePayload> writeTo,
final FlagSyncServiceStub serviceStub,
final FlagSyncServiceBlockingStub serviceBlockingStub,
final String selector,
final int deadline)
final int deadline,
final int streamDeadlineMs)
throws InterruptedException {

final BlockingQueue<GrpcResponseModel> streamReceiver = new LinkedBlockingQueue<>(QUEUE_SIZE);
@@ -128,14 +134,20 @@ static void observeEventStream(final BlockingQueue<QueuePayload> writeTo,
log.debug("Initializing sync stream request");
final SyncFlagsRequest.Builder syncRequest = SyncFlagsRequest.newBuilder();
final GetMetadataRequest.Builder metadataRequest = GetMetadataRequest.newBuilder();
GetMetadataResponse metadataResponse = GetMetadataResponse.getDefaultInstance();
GetMetadataResponse metadataResponse = GetMetadataResponse.getDefaultInstance();

if (selector != null) {
syncRequest.setSelector(selector);
}

try (CancellableContext context = Context.current().withCancellation()) {
serviceStub.syncFlags(syncRequest.build(), new GrpcStreamHandler(streamReceiver));
FlagSyncServiceStub localServiceStub = serviceStub;
if (streamDeadlineMs > 0) {
localServiceStub = localServiceStub.withDeadlineAfter(streamDeadlineMs, TimeUnit.MILLISECONDS);
}

localServiceStub.syncFlags(syncRequest.build(), new GrpcStreamHandler(streamReceiver));

try {
metadataResponse = serviceBlockingStub.withDeadlineAfter(deadline, TimeUnit.MILLISECONDS)
.getMetadata(metadataRequest.build());
@@ -158,14 +170,21 @@ static void observeEventStream(final BlockingQueue<QueuePayload> writeTo,
}

if (response.getError() != null || metadataException != null) {
log.error(String.format("Error from initializing stream or metadata, retrying in %dms",
retryDelay), response.getError());

if (!writeTo.offer(
new QueuePayload(QueuePayloadType.ERROR, "Error from stream or metadata",
metadataResponse))) {
log.error("Failed to convey ERROR status, queue is full");
if (response.getError() instanceof StatusRuntimeException
&& ((StatusRuntimeException) response.getError()).getStatus().getCode()
.equals(Code.DEADLINE_EXCEEDED)) {
log.debug(String.format("Stream deadline reached, re-establishing in %dms",
retryDelay));
} else {
log.error(String.format("Error initializing stream or metadata, retrying in %dms",
retryDelay), response.getError());
if (!writeTo.offer(
new QueuePayload(QueuePayloadType.ERROR, "Error from stream or metadata",
metadataResponse))) {
log.error("Failed to convey ERROR status, queue is full");
}
}

// close the context to cancel the stream in case just the metadata call failed
context.cancel(metadataException);
break;
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -24,6 +25,8 @@

import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache;
import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;

class EventStreamObserverTest {

@@ -83,6 +86,15 @@ public void reconnections() {
assertFalse(states.get(0));
}

@Test
public void deadlineExceeded() {
stream.onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED));
// we flush the cache
verify(cache, never()).clear();
// we notify the error
assertEquals(0, states.size());
}

@Test
public void cacheBustingForKnownKeys() {
final String key1 = "myKey1";
Original file line number Diff line number Diff line change
@@ -2,22 +2,15 @@

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.junit.jupiter.api.Test;
@@ -37,6 +30,8 @@
import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub;
import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub;
import io.grpc.Channel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
@@ -58,7 +53,7 @@ void validate_retry_calls(int retries) throws NoSuchFieldException, IllegalAcces

final Cache cache = new Cache("disabled", 0);

final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
doAnswer(invocation -> null).when(mockStub).eventStream(any(), any());

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true,
@@ -94,7 +89,7 @@ void validate_retry_calls(int retries) throws NoSuchFieldException, IllegalAcces
@Test
void initialization_succeed_with_connected_status() throws NoSuchFieldException, IllegalAccessException {
final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
Consumer<ConnectionEvent> onConnectionEvent = mock(Consumer.class);
doAnswer((InvocationOnMock invocation) -> {
EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1);
@@ -126,9 +121,9 @@ void initialization_succeed_with_connected_status() throws NoSuchFieldException,
}

@Test
void initialization_fail_with_timeout() throws Exception {
void stream_fails_with_error() throws Exception {
final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceStub mockStub = createServiceStubMock();
Consumer<ConnectionEvent> onConnectionEvent = mock(Consumer.class);
doAnswer((InvocationOnMock invocation) -> {
EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1);
@@ -159,13 +154,47 @@ void initialization_fail_with_timeout() throws Exception {
}
}

@Test
void stream_does_not_fail_with_deadline_error() throws Exception {
final Cache cache = new Cache("disabled", 0);
final ServiceStub mockStub = createServiceStubMock();
Consumer<ConnectionEvent> onConnectionEvent = mock(Consumer.class);
doAnswer((InvocationOnMock invocation) -> {
EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1);
eventStreamObserver
.onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED));
return null;
}).when(mockStub).eventStream(any(), any());

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
mockStaticService.when(() -> ServiceGrpc.newStub(any()))
.thenReturn(mockStub);

// pass true in connected lambda
final GrpcConnector connector = new GrpcConnector(FlagdOptions.builder().build(), cache, () -> {
try {
Thread.sleep(100);
return true;
} catch (Exception e) {
}
return false;

},
onConnectionEvent);

assertDoesNotThrow(connector::initialize);
// this should not call the connection event
verify(onConnectionEvent, never()).accept(any());
}
}

@Test
void host_and_port_arg_should_build_tcp_socket() {
final String host = "host.com";
final int port = 1234;

ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
@@ -196,7 +225,7 @@ void no_args_host_and_port_env_set_should_build_tcp_socket() throws Exception {

new EnvironmentVariables("FLAGD_HOST", host, "FLAGD_PORT", String.valueOf(port)).execute(() -> {
ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
@@ -230,7 +259,7 @@ void path_arg_should_build_domain_socket_with_correct_path() {
final String path = "/some/path";

ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
@@ -304,6 +333,50 @@ void no_args_socket_env_should_build_domain_socket_with_correct_path() throws Ex
});
}

@Test
void initialization_with_stream_deadline() throws NoSuchFieldException, IllegalAccessException {
final FlagdOptions options = FlagdOptions.builder()
.streamDeadlineMs(16983)
.build();

final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub);

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null);

assertDoesNotThrow(connector::initialize);
verify(mockStub).withDeadlineAfter(16983, TimeUnit.MILLISECONDS);
}
}

@Test
void initialization_without_stream_deadline() throws NoSuchFieldException, IllegalAccessException {
final FlagdOptions options = FlagdOptions.builder()
.streamDeadlineMs(0)
.build();

final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub);

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null);

assertDoesNotThrow(connector::initialize);
verify(mockStub, never()).withDeadlineAfter(16983, TimeUnit.MILLISECONDS);
}
}

private static ServiceStub createServiceStubMock() {
final ServiceStub mockStub = mock(ServiceStub.class);
when(mockStub.withDeadlineAfter(anyLong(), any())).thenReturn(mockStub);
return mockStub;
}

private NettyChannelBuilder getMockChannelBuilderSocket() {
NettyChannelBuilder mockChannelBuilder = mock(NettyChannelBuilder.class);
when(mockChannelBuilder.eventLoopGroup(any(EventLoopGroup.class))).thenReturn(mockChannelBuilder);
Original file line number Diff line number Diff line change
@@ -7,12 +7,7 @@
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.time.Duration;
@@ -42,6 +37,7 @@ public void connectionParameters() throws Throwable {
final FlagdOptions options = FlagdOptions.builder()
.selector("selector")
.deadline(1337)
.streamDeadlineMs(87699)
.build();

final GrpcStreamConnector connector = new GrpcStreamConnector(options);
@@ -58,6 +54,37 @@ public void connectionParameters() throws Throwable {
connector.init();
verify(stubMock, timeout(MAX_WAIT_MS.toMillis()).times(1)).syncFlags(any(), any());
verify(blockingStubMock).withDeadlineAfter(1337, TimeUnit.MILLISECONDS);
verify(stubMock).withDeadlineAfter(87699, TimeUnit.MILLISECONDS);

// then
final SyncFlagsRequest flagsRequest = request[0];
assertNotNull(flagsRequest);
assertEquals("selector", flagsRequest.getSelector());
}


@Test
public void disableStreamDeadline() throws Throwable {
// given
final FlagdOptions options = FlagdOptions.builder()
.selector("selector")
.streamDeadlineMs(0)
.build();

final GrpcStreamConnector connector = new GrpcStreamConnector(options);
final FlagSyncServiceStub stubMock = mockStubAndReturn(connector);
final FlagSyncServiceBlockingStub blockingStubMock = mockBlockingStubAndReturn(connector);
final SyncFlagsRequest[] request = new SyncFlagsRequest[1];

doAnswer(invocation -> {
request[0] = invocation.getArgument(0, SyncFlagsRequest.class);
return null;
}).when(stubMock).syncFlags(any(), any());

// when
connector.init();
verify(stubMock, timeout(MAX_WAIT_MS.toMillis()).times(1)).syncFlags(any(), any());
verify(stubMock, never()).withDeadlineAfter(anyLong(), any());

// then
final SyncFlagsRequest flagsRequest = request[0];
@@ -186,6 +213,7 @@ private static FlagSyncServiceStub mockStubAndReturn(final GrpcStreamConnector c
serviceStubField.setAccessible(true);

final FlagSyncServiceStub stubMock = mock(FlagSyncServiceStub.class);
when(stubMock.withDeadlineAfter(anyLong(), any())).thenReturn(stubMock);

serviceStubField.set(connector, stubMock);