
[FEATURE][ML] Parse results and join them in the data-frame copy index #36382


Merged
File: [name not captured; the transport action handling RunAnalyticsAction]

@@ -38,7 +38,6 @@
 import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory;
 import org.elasticsearch.xpack.ml.analytics.DataFrameFields;
 import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager;
@@ -178,8 +177,7 @@ private void runPipelineAnalytics(String index, ActionListener<AcknowledgedResponse> listener) {

         ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
             dataExtractorFactory -> {
-                DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor();
-                analyticsProcessManager.processData(jobId, dataExtractor);
+                analyticsProcessManager.runJob(jobId, dataExtractorFactory);
                 listener.onResponse(new AcknowledgedResponse(true));
             },
             listener::onFailure
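For context, ActionListener.wrap(...) builds a listener from a success handler and a failure handler, so a factory-creation failure funnels straight into the caller's listener. A minimal illustration of the pattern with toy types (not part of this PR):

    import org.elasticsearch.action.ActionListener;

    class ListenerWrapExample {
        static void run() {
            // wrap(onResponse, onFailure): the first lambda runs on success,
            // the second on any failure signalled via listener.onFailure(e)
            ActionListener<String> listener = ActionListener.wrap(
                response -> System.out.println("got: " + response),
                e -> System.err.println("failed: " + e.getMessage())
            );
            listener.onResponse("hello");   // prints "got: hello"
        }
    }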
File: DataFrameDataExtractor.java

@@ -98,7 +98,7 @@ private SearchRequestBuilder buildSearchRequest() {
             .setIndices(context.indices)
             .setSize(context.scrollSize)
             .setQuery(context.query)
-            .setFetchSource(false);
+            .setFetchSource(context.includeSource);

         for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
             searchRequestBuilder.addDocValueField(docValueField.getName(), docValueField.getDocValueFormat());
@@ -149,7 +149,7 @@ private Row createRow(SearchHit hit) {
                 break;
             }
         }
-        return new Row(extractedValues);
+        return new Row(extractedValues, hit);
     }

     private List<Row> continueScroll() throws IOException {
@@ -196,10 +196,11 @@ public DataSummary collectDataSummary() {
         SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
             .setIndices(context.indices)
             .setSize(0)
-            .setQuery(context.query);
+            .setQuery(context.query)
+            .setTrackTotalHits(true);

         SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
-        return new DataSummary(searchResponse.getHits().getTotalHits(), context.extractedFields.getAllFields().size());
+        return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size());
     }

     public static class DataSummary {
@@ -215,16 +216,27 @@ public DataSummary(long rows, int cols) {

     public static class Row {

+        private SearchHit hit;
+
         @Nullable
         private String[] values;

-        private Row(String[] values) {
+        private Row(String[] values, SearchHit hit) {
             this.values = values;
+            this.hit = hit;
         }

         @Nullable
         public String[] getValues() {
             return values;
         }
+
+        public SearchHit getHit() {
+            return hit;
+        }
+
+        public boolean shouldSkip() {
+            return values == null;
+        }
     }
 }
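A sketch of how a consumer can use the SearchHit now carried on each Row to join per-document results back into the copy index. The rows and resultsById variables and the merge step are illustrative only, not part of this diff:

    // assumes an extractor created with includeSource == true, so hits carry _source
    for (DataFrameDataExtractor.Row row : rows) {
        if (row.shouldSkip()) {
            continue;                                    // hit had no extracted values
        }
        SearchHit hit = row.getHit();                    // the original document
        Map<String, Object> doc = hit.getSourceAsMap();  // mutable copy of _source
        doc.putAll(resultsById.get(hit.getId()));        // merge the analytics results
        // ... index 'doc' into the data-frame copy index ...
    }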
File: DataFrameDataExtractorContext.java

@@ -20,14 +20,16 @@ public class DataFrameDataExtractorContext {
     final QueryBuilder query;
     final int scrollSize;
     final Map<String, String> headers;
+    final boolean includeSource;

     DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, QueryBuilder query, int scrollSize,
-                                  Map<String, String> headers) {
+                                  Map<String, String> headers, boolean includeSource) {
         this.jobId = Objects.requireNonNull(jobId);
         this.extractedFields = Objects.requireNonNull(extractedFields);
         this.indices = indices.toArray(new String[indices.size()]);
         this.query = Objects.requireNonNull(query);
         this.scrollSize = scrollSize;
         this.headers = headers;
+        this.includeSource = includeSource;
     }
 }
File: DataFrameDataExtractorFactory.java

@@ -63,14 +63,16 @@ private DataFrameDataExtractorFactory(Client client, String index, ExtractedFields extractedFields) {
         this.extractedFields = Objects.requireNonNull(extractedFields);
     }

-    public DataFrameDataExtractor newExtractor() {
+    public DataFrameDataExtractor newExtractor(boolean includeSource) {
         DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
             "ml-analytics-" + index,
             extractedFields,
             Arrays.asList(index),
             QueryBuilders.matchAllQuery(),
             1000,
-            Collections.emptyMap());
+            Collections.emptyMap(),
+            includeSource
+        );
         return new DataFrameDataExtractor(client, context);
     }
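With the new flag, the same factory can hand out both flavours of extractor, which is how runJob (further down in this diff) uses it:

    DataFrameDataExtractor featureExtractor = factory.newExtractor(false); // no _source: rows feed the native process
    DataFrameDataExtractor joinExtractor = factory.newExtractor(true);     // with _source: rows are joined with results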
File: AnalyticsControlMessageWriter.java

@@ -18,7 +18,7 @@ public class AnalyticsControlMessageWriter extends AbstractControlMsgWriter {
      * but in the context of the java side it is more descriptive to call this the
      * end of data message.
      */
-    private static final String END_OF_DATA_MESSAGE_CODE = "r";
+    private static final String END_OF_DATA_MESSAGE_CODE = "$";
Comment (Contributor):
I think it's still 'r' on the C++ side.

Comment (Contributor Author):
[comment body not captured]

Comment (Contributor):
OK cool. I missed that.

     /**
      * Construct the control message writer with a LengthEncodedWriter
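For reference, a sketch of how the constant above would typically be sent, assuming the writeMessage(String) helper inherited from AbstractControlMsgWriter (the helper's exact signature is an assumption here):

    public void writeEndOfData() throws IOException {
        // single-character control code, matched by the C++ side
        writeMessage(END_OF_DATA_MESSAGE_CODE);
    }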
File: AnalyticsProcess.java

@@ -8,6 +8,7 @@
 import org.elasticsearch.xpack.ml.process.NativeProcess;

 import java.io.IOException;
+import java.util.Iterator;

 public interface AnalyticsProcess extends NativeProcess {

@@ -17,4 +18,17 @@ public interface AnalyticsProcess extends NativeProcess {
      * @throws IOException If an error occurs writing to the process
      */
     void writeEndOfDataMessage() throws IOException;
+
+    /**
+     * @return stream of analytics results.
+     */
+    Iterator<AnalyticsResult> readAnalyticsResults();
+
+    /**
+     * Read anything left in the stream before closing it;
+     * otherwise, if the process tries to write more after
+     * the close, it gets a SIGPIPE.
+     */
+    void consumeAndCloseOutputStream();
 }
File: AnalyticsProcessManager.java

@@ -17,6 +17,7 @@
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis;
 import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor;
+import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory;

 import java.io.IOException;
 import java.util.List;
@@ -41,28 +42,39 @@ public AnalyticsProcessManager(Client client, Environment environment, ThreadPool threadPool,
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
     }

-    public void processData(String jobId, DataFrameDataExtractor dataExtractor) {
+    public void runJob(String jobId, DataFrameDataExtractorFactory dataExtractorFactory) {
         threadPool.generic().execute(() -> {
-            AnalyticsProcess process = createProcess(jobId, dataExtractor);
-            try {
-                writeHeaderRecord(dataExtractor, process);
-                writeDataRows(dataExtractor, process);
-                process.writeEndOfDataMessage();
-                process.flushStream();
+            DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false);
+            AnalyticsProcess process = createProcess(jobId, createProcessConfig(dataExtractor));
+            ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME);
+            AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(client, dataExtractorFactory.newExtractor(true));
+            executorService.execute(() -> resultProcessor.process(process));
+            executorService.execute(() -> processData(jobId, dataExtractor, process, resultProcessor));
+        });
+    }
+
+    private void processData(String jobId, DataFrameDataExtractor dataExtractor, AnalyticsProcess process,
+                             AnalyticsResultProcessor resultProcessor) {
+        try {
+            writeHeaderRecord(dataExtractor, process);
+            writeDataRows(dataExtractor, process);
+            process.writeEndOfDataMessage();
+            process.flushStream();

-            LOGGER.debug("[{}] Closing process", jobId);
+            LOGGER.info("[{}] Waiting for result processor to complete", jobId);
Comment (Contributor):
I don't think we'll want this to be an INFO in production. If you want to leave it like this on the feature branch please add a TODO to downgrade it.

Comment (Contributor Author):
Indeed, this shouldn't be an info message. I left it on purpose as I think it's useful during development. I had in mind that we'd review all logging when this is refactored into persistent tasks. Not sure I'd have a todo for each one of them, but I can add them if you think it's best.

+            resultProcessor.awaitForCompletion();
+            LOGGER.info("[{}] Result processor has completed", jobId);
+        } catch (IOException e) {
+            LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e);
+        } finally {
+            LOGGER.info("[{}] Closing process", jobId);
Comment (Contributor):
Ditto either downgrade this now or add a TODO to do so before merging the feature branch.
+            try {
+                process.close();
+                LOGGER.info("[{}] Closed process", jobId);
+            } catch (IOException e) {
+                LOGGER.error("[{}] Error closing data frame analyzer process", jobId);
+            }
+        }
+    }
-            } catch (IOException e) {
-                LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e);
-            } finally {
-                try {
-                    process.close();
-                } catch (IOException e) {
-                    LOGGER.error("[{}] Error closing data frame analyzer process", jobId);
-                }
-            }
-        });
-    }

@@ -75,8 +87,8 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
         Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
         if (rows.isPresent()) {
             for (DataFrameDataExtractor.Row row : rows.get()) {
-                String[] rowValues = row.getValues();
-                if (rowValues != null) {
+                if (row.shouldSkip() == false) {
+                    String[] rowValues = row.getValues();
                     System.arraycopy(rowValues, 0, record, 0, rowValues.length);
                     process.writeRecord(record);
                 }
@@ -96,10 +108,10 @@ private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
         process.writeRecord(headerRecord);
     }

-    private AnalyticsProcess createProcess(String jobId, DataFrameDataExtractor dataExtractor) {
+    private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig) {
         // TODO We should rename the thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME
         ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME);
-        AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, createProcessConfig(dataExtractor), executorService);
+        AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, analyticsProcessConfig, executorService);
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start analytics process");
         }
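AnalyticsResultProcessor itself is not included in this diff. A minimal sketch of the contract the manager relies on, assuming a CountDownLatch-based implementation: process(...) consumes the results stream, and awaitForCompletion() blocks until it finishes.

    class AnalyticsResultProcessorSketch {

        private final CountDownLatch completionLatch = new CountDownLatch(1);

        public void process(AnalyticsProcess process) {
            try {
                Iterator<AnalyticsResult> results = process.readAnalyticsResults();
                while (results.hasNext()) {
                    AnalyticsResult result = results.next();
                    // join result.getResults() back to the document matching result.getIdHash()
                }
            } finally {
                completionLatch.countDown();
                process.consumeAndCloseOutputStream();
            }
        }

        public void awaitForCompletion() {
            try {
                completionLatch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }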
File: AnalyticsResult.java (new file)

@@ -0,0 +1,73 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class AnalyticsResult implements ToXContentObject {

    public static final ParseField TYPE = new ParseField("analytics_result");
    public static final ParseField ID_HASH = new ParseField("id_hash");
    public static final ParseField RESULTS = new ParseField("results");

    static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
        a -> new AnalyticsResult((String) a[0], (Map<String, Object>) a[1]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_HASH);
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS);
    }

    private final String idHash;
    private final Map<String, Object> results;

    public AnalyticsResult(String idHash, Map<String, Object> results) {
        this.idHash = Objects.requireNonNull(idHash);
        this.results = Objects.requireNonNull(results);
    }

    public String getIdHash() {
        return idHash;
    }

    public Map<String, Object> getResults() {
        return results;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ID_HASH.getPreferredName(), idHash);
        builder.field(RESULTS.getPreferredName(), results);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }

        AnalyticsResult that = (AnalyticsResult) other;
        return Objects.equals(idHash, that.idHash) && Objects.equals(results, that.results);
    }

    @Override
    public int hashCode() {
        return Objects.hash(idHash, results);
    }
}
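A round-trip usage sketch for the parser above (same package, since PARSER is package-private). The createParser(...) signature shown assumes the NamedXContentRegistry/DeprecationHandler variant; it differs across Elasticsearch versions:

    String json = "{\"id_hash\":\"abc123\",\"results\":{\"cluster\":3}}";
    try (XContentParser parser = JsonXContent.jsonXContent.createParser(
            NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
        AnalyticsResult result = AnalyticsResult.PARSER.apply(parser, null);
        assert "abc123".equals(result.getIdHash());   // idHash survives the round trip
    }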