Skip to content

[ML] Create and inject APM Inference Metrics #111293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.ToXContentObject;

Expand Down Expand Up @@ -48,5 +49,6 @@ default DenseVectorFieldMapper.ElementType elementType() {
* be chosen when initializing a deployment within their service. In this situation, return null.
* @return the model used to perform inference or null if the model is not defined
*/
@Nullable
String modelId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
import org.elasticsearch.plugins.MapperPlugin;
Expand Down Expand Up @@ -84,8 +85,8 @@
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.InferenceAPMStats;
import org.elasticsearch.xpack.inference.telemetry.StatsMap;
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -196,10 +197,10 @@ public Collection<?> createComponents(PluginServices services) {
var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry);
shardBulkInferenceActionFilter.set(actionFilter);

var statsFactory = new InferenceAPMStats.Factory(services.telemetryProvider().getMeterRegistry());
var statsMap = new StatsMap<>(InferenceAPMStats::key, statsFactory::newInferenceRequestAPMCounter);
var meterRegistry = services.telemetryProvider().getMeterRegistry();
var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));

return List.of(modelRegistry, registry, httpClientManager, statsMap);
return List.of(modelRegistry, registry, httpClientManager, stats);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,26 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {

private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;

@Inject
public TransportInferenceAction(
TransportService transportService,
ActionFilters actionFilters,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats
) {
super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
this.modelRegistry = modelRegistry;
this.serviceRegistry = serviceRegistry;
this.inferenceStats = inferenceStats;
}

@Override
Expand Down Expand Up @@ -76,6 +80,7 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
unparsedModel.settings(),
unparsedModel.secrets()
);
inferenceStats.incrementRequestCount(model);
inferOnService(model, request, service.get(), delegate);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map<String,
}

public CohereEmbeddingsModel(
String modelId,
String inferenceId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
Expand All @@ -37,7 +37,7 @@ public CohereEmbeddingsModel(
ConfigurationParseContext context
) {
this(
modelId,
inferenceId,
taskType,
service,
CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public OpenAiEmbeddingsServiceSettings(
@Nullable RateLimitSettings rateLimitSettings
) {
this.uri = uri;
this.modelId = modelId;
this.modelId = Objects.requireNonNull(modelId);
this.organizationId = organizationId;
this.similarity = similarity;
this.dimensions = dimensions;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.telemetry;

import org.elasticsearch.inference.Model;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.MeterRegistry;

import java.util.HashMap;
import java.util.Objects;

/**
 * {@link InferenceStats} implementation backed by an APM {@link LongCounter}.
 * Each inference request is recorded with service, task type, and (when
 * present) model ID attached as metric attributes.
 */
public class ApmInferenceStats implements InferenceStats {
    private final LongCounter requestCounter;

    /**
     * @param inferenceAPMRequestCounter the APM counter to record requests on; must not be null
     */
    public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
        this.requestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
    }

    /**
     * Records a single inference request against the APM counter, tagged with
     * the model's service name, task type, and model ID (omitted when null).
     *
     * @param model the model the request was made against
     */
    @Override
    public void incrementRequestCount(Model model) {
        var attributes = new HashMap<String, Object>(5);
        attributes.put("service", model.getConfigurations().getService());
        attributes.put("task_type", model.getTaskType().toString());
        // model_id is optional in service settings; only attach it when defined
        var modelId = model.getServiceSettings().modelId();
        if (modelId != null) {
            attributes.put("model_id", modelId);
        }

        requestCounter.incrementBy(1, attributes);
    }

    /**
     * Factory that registers the inference request counter on the given
     * registry and wraps it in an {@link ApmInferenceStats}.
     *
     * @param meterRegistry the registry to register the counter with
     * @return a new stats instance backed by the registered counter
     */
    public static ApmInferenceStats create(MeterRegistry meterRegistry) {
        var counter = meterRegistry.registerLongCounter(
            "es.inference.requests.count.total",
            "Inference API request counts for a particular service, task type, model ID",
            "operations"
        );
        return new ApmInferenceStats(counter);
    }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,14 @@
package org.elasticsearch.xpack.inference.telemetry;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.InferenceRequestStats;

import java.util.Objects;
import java.util.concurrent.atomic.LongAdder;
public interface InferenceStats {

public class InferenceStats implements Stats {
protected final String service;
protected final TaskType taskType;
protected final String modelId;
protected final LongAdder counter = new LongAdder();
    /**
     * Increments the request counter for the given model in a thread-safe manner.
     * @param model the model to increment the request count for
     */
void incrementRequestCount(Model model);

public static String key(Model model) {
StringBuilder builder = new StringBuilder();
builder.append(model.getConfigurations().getService());
builder.append(":");
builder.append(model.getTaskType());

if (model.getServiceSettings().modelId() != null) {
builder.append(":");
builder.append(model.getServiceSettings().modelId());
}

return builder.toString();
}

public InferenceStats(Model model) {
Objects.requireNonNull(model);

service = model.getConfigurations().getService();
taskType = model.getTaskType();
modelId = model.getServiceSettings().modelId();
}

@Override
public void increment() {
counter.increment();
}

@Override
public long getCount() {
return counter.sum();
}

@Override
public InferenceRequestStats toSerializableForm() {
return new InferenceRequestStats(service, taskType, modelId, getCount());
}
InferenceStats NOOP = model -> {};
}

This file was deleted.

This file was deleted.

Loading