Skip to content

ES APM traces for HTTP requests include authn duration #96205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -141,7 +142,8 @@ protected void doStop() {
@Override
protected void doClose() {}

private APMServices createApmServices() {
// package-private for tests
APMServices createApmServices() {
assert this.enabled;
assert this.services == null;

Expand Down Expand Up @@ -187,6 +189,11 @@ public void startTrace(ThreadContext threadContext, SpanId spanId, String spanNa
}

setSpanAttributes(threadContext, attributes, spanBuilder);

Instant startTime = threadContext.getTransient(Task.TRACE_START_TIME);
if (startTime != null) {
spanBuilder.setStartTimestamp(startTime);
}
final Span span = spanBuilder.startSpan();
final Context contextForNewSpan = Context.current().with(span);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,27 @@

package org.elasticsearch.tracing.apm;

import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.SpanContext;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;

import org.apache.lucene.util.automaton.CharacterRunAutomaton;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.tracing.SpanId;

import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import static org.elasticsearch.tracing.apm.APMAgentSettings.APM_ENABLED_SETTING;
Expand All @@ -24,8 +37,12 @@
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

public class APMTracerTests extends ESTestCase {

Expand Down Expand Up @@ -74,7 +91,25 @@ public void test_onTraceStarted_startsTrace() {
}

/**
* Check that when a trace is started, the tracer ends the span and removes the record of it.
* Checks that when a trace is started with a specific start time, the tracer starts a span and records it.
*/
public void test_onTraceStartedWithStartTime_startsTrace() {
Settings settings = Settings.builder().put(APM_ENABLED_SETTING.getKey(), true).build();
APMTracer apmTracer = buildTracer(settings);

ThreadContext threadContext = new ThreadContext(settings);
// 1_000_000L because of "toNanos" conversions that overflow for large long millis
Instant spanStartTime = Instant.ofEpochMilli(randomLongBetween(0, Long.MAX_VALUE / 1_000_000L));
threadContext.putTransient(Task.TRACE_START_TIME, spanStartTime);
apmTracer.startTrace(threadContext, SPAN_ID1, "name1", null);

assertThat(apmTracer.getSpans(), aMapWithSize(1));
assertThat(apmTracer.getSpans(), hasKey(SPAN_ID1));
assertThat(((SpyAPMTracer) apmTracer).getSpanStartTime("name1"), is(spanStartTime));
}

/**
* Check that when a trace is stopped, the tracer ends the span and removes the record of it.
*/
public void test_onTraceStopped_stopsTrace() {
Settings settings = Settings.builder().put(APM_ENABLED_SETTING.getKey(), true).build();
Expand Down Expand Up @@ -211,8 +246,120 @@ public void test_whenAddingAttributes_thenSensitiveValuesAreRedacted() {
}

private APMTracer buildTracer(Settings settings) {
APMTracer tracer = new APMTracer(settings);
APMTracer tracer = new SpyAPMTracer(settings);
tracer.doStart();
return tracer;
}

static class SpyAPMTracer extends APMTracer {

Map<String, Instant> spanStartTimeMap;

SpyAPMTracer(Settings settings) {
super(settings);
this.spanStartTimeMap = new HashMap<>();
}

@Override
APMServices createApmServices() {
APMServices apmServices = super.createApmServices();
Tracer mockTracer = mock(Tracer.class);
doAnswer(invocation -> {
String spanName = (String) invocation.getArguments()[0];
// spy the spanBuilder
return new SpySpanBuilder(apmServices.tracer(), spanName);
}).when(mockTracer).spanBuilder(anyString());
return new APMServices(mockTracer, apmServices.openTelemetry());
}

Instant getSpanStartTime(String spanName) {
return spanStartTimeMap.get(spanName);
}

class SpySpanBuilder implements SpanBuilder {

SpanBuilder delegatedSpanBuilder;
Instant startTime;
String spanName;

SpySpanBuilder(Tracer tracer, String spanName) {
this.delegatedSpanBuilder = tracer.spanBuilder(spanName);
this.spanName = spanName;
}

@Override
public SpanBuilder setParent(Context context) {
delegatedSpanBuilder.setParent(context);
return this;
}

@Override
public SpanBuilder setNoParent() {
delegatedSpanBuilder.setNoParent();
return this;
}

@Override
public SpanBuilder addLink(SpanContext spanContext) {
delegatedSpanBuilder.addLink(spanContext);
return this;
}

@Override
public SpanBuilder addLink(SpanContext spanContext, Attributes attributes) {
delegatedSpanBuilder.addLink(spanContext, attributes);
return this;
}

@Override
public SpanBuilder setAttribute(String key, String value) {
delegatedSpanBuilder.setAttribute(key, value);
return this;
}

@Override
public SpanBuilder setAttribute(String key, long value) {
delegatedSpanBuilder.setAttribute(key, value);
return this;
}

@Override
public SpanBuilder setAttribute(String key, double value) {
delegatedSpanBuilder.setAttribute(key, value);
return this;
}

@Override
public SpanBuilder setAttribute(String key, boolean value) {
delegatedSpanBuilder.setAttribute(key, value);
return this;
}

@Override
public <T> SpanBuilder setAttribute(AttributeKey<T> key, T value) {
delegatedSpanBuilder.setAttribute(key, value);
return this;
}

@Override
public SpanBuilder setSpanKind(SpanKind spanKind) {
delegatedSpanBuilder.setSpanKind(spanKind);
return this;
}

@Override
public SpanBuilder setStartTimestamp(long startTimestamp, TimeUnit unit) {
startTime = Instant.ofEpochMilli(TimeUnit.MILLISECONDS.convert(startTimestamp, unit));
delegatedSpanBuilder.setStartTimestamp(startTimestamp, unit);
return this;
}

@Override
public Span startSpan() {
// finally record the spanName-startTime association when the span is actually started
spanStartTimeMap.put(spanName, startTime);
return delegatedSpanBuilder.startSpan();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@
import org.elasticsearch.tracing.Tracer;
import org.elasticsearch.usage.UsageService;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -558,6 +559,9 @@ public ActionModule(
* finishes and returns.
*/
public void copyRequestHeadersToThreadContext(HttpPreRequest request, ThreadContext threadContext) {
// the request's thread-context must always be populated (by calling this method) before undergoing any request related processing
// we use this opportunity to first record the request processing start time
threadContext.putTransient(Task.TRACE_START_TIME, Instant.ofEpochMilli(threadPool.absoluteTimeInMillis()));
for (final RestHeaderDefinition restHeader : headersToCopy) {
final String name = restHeader.getName();
final List<String> headerValues = request.getHeaders().get(name);
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/tasks/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public class Task {
* Has to be declared as a header copied over for tasks.
*/
public static final String TRACE_ID = "trace.id";

public static final String TRACE_START_TIME = "trace.starttime";
public static final String TRACE_PARENT = "traceparent";

public static final Set<String> HEADERS_TO_COPY = Set.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;

Expand Down Expand Up @@ -219,6 +220,8 @@ public void dispatchRequest(final RestRequest request, final RestChannel channel
// specified request headers value are copied into the thread context
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
// trace start time is also set
assertThat(threadContext.getTransient(Task.TRACE_START_TIME), notNullValue());
// but unknown headers are not copied at all
assertNull(threadContext.getHeader("header.3"));
}
Expand All @@ -229,6 +232,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th
assertNull(threadContext.getHeader("header.1"));
assertNull(threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
assertNull(threadContext.getTransient(Task.TRACE_START_TIME));
}

};
Expand Down Expand Up @@ -312,6 +316,8 @@ public void dispatchRequest(final RestRequest request, final RestChannel channel
assertThat(threadContext.getHeader(Task.TRACE_ID), equalTo("0af7651916cd43dd8448eb211c80319c"));
assertThat(threadContext.getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
assertThat(threadContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), equalTo(traceParentValue));
// request trace start time is also set
assertThat(threadContext.getTransient(Task.TRACE_START_TIME), notNullValue());
}

@Override
Expand All @@ -320,6 +326,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th
assertThat(threadContext.getHeader(Task.TRACE_ID), nullValue());
assertThat(threadContext.getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
assertThat(threadContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), nullValue());
assertThat(threadContext.getTransient(Task.TRACE_START_TIME), nullValue());
}

};
Expand Down