Skip to content

Commit 9fd1d10

Browse files
[FEATURE][ML] Add checksum checks on dataframe result joining (#37259)
In order to sanity check that analytics results are joined correctly with their corresponding dataframe rows, we write a checksum for each dataframe row which is a 32-bit hash of the analysis fields. The analytics process includes it in the results. Upon joining we check that the checksums match.
1 parent 371f5d7 commit 9fd1d10

File tree

9 files changed

+98
-25
lines changed

9 files changed

+98
-25
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import java.io.IOException;
2727
import java.util.ArrayList;
28+
import java.util.Arrays;
2829
import java.util.List;
2930
import java.util.NoSuchElementException;
3031
import java.util.Objects;
@@ -238,5 +239,9 @@ public SearchHit getHit() {
238239
public boolean shouldSkip() {
239240
return values == null;
240241
}
242+
243+
public int getChecksum() {
244+
return Arrays.hashCode(values);
245+
}
241246
}
242247
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapa
109109
Set<String> fields = fieldCapabilitiesResponse.get().keySet();
110110
fields.removeAll(IGNORE_FIELDS);
111111
removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse);
112-
ExtractedFields extractedFields = ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse)
112+
List<String> sortedFields = new ArrayList<>(fields);
113+
// We sort the fields to ensure the checksum for each document is deterministic
114+
Collections.sort(sortedFields);
115+
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse)
113116
.filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
114117
if (extractedFields.getAllFields().isEmpty()) {
115118
throw ExceptionsHelper.badRequestException("No compatible fields could be detected");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ private void processData(String jobId, DataFrameDataExtractor dataExtractor, Ana
7878
}
7979

8080
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
81-
// The extra field is the control field (should be an empty string)
82-
String[] record = new String[dataExtractor.getFieldNames().size() + 1];
81+
// The extra fields are for the doc hash and the control field (should be an empty string)
82+
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
8383
// The value of the control field should be an empty string for data frame rows
8484
record[record.length - 1] = "";
8585

@@ -90,6 +90,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces
9090
if (row.shouldSkip() == false) {
9191
String[] rowValues = row.getValues();
9292
System.arraycopy(rowValues, 0, record, 0, rowValues.length);
93+
record[record.length - 2] = String.valueOf(row.getChecksum());
9394
process.writeRecord(record);
9495
}
9596
}
@@ -99,11 +100,16 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces
99100

100101
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
101102
List<String> fieldNames = dataExtractor.getFieldNames();
102-
String[] headerRecord = new String[fieldNames.size() + 1];
103+
104+
// We add 2 extra fields, both named dot:
105+
// - the document hash
106+
// - the control message
107+
String[] headerRecord = new String[fieldNames.size() + 2];
103108
for (int i = 0; i < fieldNames.size(); i++) {
104109
headerRecord[i] = fieldNames.get(i);
105110
}
106-
// The field name of the control field is dot
111+
112+
headerRecord[headerRecord.length - 2] = ".";
107113
headerRecord[headerRecord.length - 1] = ".";
108114
process.writeRecord(headerRecord);
109115
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,27 @@
1717
public class AnalyticsResult implements ToXContentObject {
1818

1919
public static final ParseField TYPE = new ParseField("analytics_result");
20-
public static final ParseField ID_HASH = new ParseField("id_hash");
20+
public static final ParseField CHECKSUM = new ParseField("checksum");
2121
public static final ParseField RESULTS = new ParseField("results");
2222

2323
static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
24-
a -> new AnalyticsResult((String) a[0], (Map<String, Object>) a[1]));
24+
a -> new AnalyticsResult((Integer) a[0], (Map<String, Object>) a[1]));
2525

2626
static {
27-
PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_HASH);
27+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM);
2828
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS);
2929
}
3030

31-
private final String idHash;
31+
private final int checksum;
3232
private final Map<String, Object> results;
3333

34-
public AnalyticsResult(String idHash, Map<String, Object> results) {
35-
this.idHash = Objects.requireNonNull(idHash);
34+
public AnalyticsResult(int checksum, Map<String, Object> results) {
35+
this.checksum = Objects.requireNonNull(checksum);
3636
this.results = Objects.requireNonNull(results);
3737
}
3838

39-
public String getIdHash() {
40-
return idHash;
39+
public int getChecksum() {
40+
return checksum;
4141
}
4242

4343
public Map<String, Object> getResults() {
@@ -47,7 +47,7 @@ public Map<String, Object> getResults() {
4747
@Override
4848
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
4949
builder.startObject();
50-
builder.field(ID_HASH.getPreferredName(), idHash);
50+
builder.field(CHECKSUM.getPreferredName(), checksum);
5151
builder.field(RESULTS.getPreferredName(), results);
5252
builder.endObject();
5353
return builder;
@@ -63,11 +63,11 @@ public boolean equals(Object other) {
6363
}
6464

6565
AnalyticsResult that = (AnalyticsResult) other;
66-
return Objects.equals(idHash, that.idHash) && Objects.equals(results, that.results);
66+
return checksum == that.checksum && Objects.equals(results, that.results);
6767
}
6868

6969
@Override
7070
public int hashCode() {
71-
return Objects.hash(idHash, results);
71+
return Objects.hash(checksum, results);
7272
}
7373
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ private void joinCurrentResults() {
9696
continue;
9797
}
9898
AnalyticsResult result = currentResults.get(i);
99+
checkChecksumsMatch(row, result);
100+
99101
SearchHit hit = row.getHit();
100102
Map<String, Object> source = new LinkedHashMap(hit.getSourceAsMap());
101103
source.putAll(result.getResults());
@@ -112,4 +114,14 @@ private void joinCurrentResults() {
112114
}
113115
}
114116
}
117+
118+
private void checkChecksumsMatch(DataFrameDataExtractor.Row row, AnalyticsResult result) {
119+
if (row.getChecksum() != result.getChecksum()) {
120+
String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; ";
121+
msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; ";
122+
msg += "this implies the data frame index [" + row.getHit().getIndex() + "] was modified while the analysis was running. ";
123+
msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable.";
124+
throw new IllegalStateException(msg);
125+
}
126+
}
115127
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessCon
4545
ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
4646
true, false, true, true, false, false);
4747

48-
// The extra 1 is the control field
49-
int numberOfFields = analyticsProcessConfig.cols() + 1;
48+
// The extra 2 are for the checksum and the control field
49+
int numberOfFields = analyticsProcessConfig.cols() + 2;
5050

5151
createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes);
5252

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
1313
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
1414

15+
import java.util.ArrayList;
16+
import java.util.Collections;
1517
import java.util.HashMap;
1618
import java.util.List;
1719
import java.util.Map;
@@ -89,6 +91,28 @@ public void testDetectExtractedFields_GivenIgnoredField() {
8991
assertThat(e.getMessage(), equalTo("No compatible fields could be detected"));
9092
}
9193

94+
public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() {
95+
int fieldCount = randomIntBetween(10, 20);
96+
List<String> fields = new ArrayList<>();
97+
for (int i = 0; i < fieldCount; i++) {
98+
fields.add(randomAlphaOfLength(20));
99+
}
100+
List<String> sortedFields = new ArrayList<>(fields);
101+
Collections.sort(sortedFields);
102+
103+
MockFieldCapsResponseBuilder mockFieldCapsResponseBuilder = new MockFieldCapsResponseBuilder();
104+
for (String field : fields) {
105+
mockFieldCapsResponseBuilder.addAggregatableField(field, "float");
106+
}
107+
FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build();
108+
109+
ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities);
110+
111+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
112+
.collect(Collectors.toList());
113+
assertThat(extractedFieldNames, equalTo(sortedFields));
114+
}
115+
92116
private static class MockFieldCapsResponseBuilder {
93117

94118
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ public void testProcess_GivenSingleRowAndResult() throws IOException {
6363

6464
String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}";
6565
String[] dataValues = {"42.0"};
66-
DataFrameDataExtractor.Row row = newRow(newHit("1", dataDoc), dataValues);
66+
DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1);
6767
givenSingleDataFrameBatch(Arrays.asList(row));
6868

6969
Map<String, Object> resultFields = new HashMap<>();
7070
resultFields.put("a", "1");
7171
resultFields.put("b", "2");
72-
AnalyticsResult result = new AnalyticsResult("some_hash", resultFields);
72+
AnalyticsResult result = new AnalyticsResult(1, resultFields);
7373
givenProcessResults(Arrays.asList(result));
7474

7575
AnalyticsResultProcessor resultProcessor = createResultProcessor();
@@ -90,6 +90,28 @@ public void testProcess_GivenSingleRowAndResult() throws IOException {
9090
assertThat(indexedDocSource.get("b"), equalTo("2"));
9191
}
9292

93+
public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException {
94+
givenClientHasNoFailures();
95+
96+
String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}";
97+
String[] dataValues = {"42.0"};
98+
DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1);
99+
givenSingleDataFrameBatch(Arrays.asList(row));
100+
101+
Map<String, Object> resultFields = new HashMap<>();
102+
resultFields.put("a", "1");
103+
resultFields.put("b", "2");
104+
AnalyticsResult result = new AnalyticsResult(2, resultFields);
105+
givenProcessResults(Arrays.asList(result));
106+
107+
AnalyticsResultProcessor resultProcessor = createResultProcessor();
108+
109+
resultProcessor.process(process);
110+
resultProcessor.awaitForCompletion();
111+
112+
verifyNoMoreInteractions(client);
113+
}
114+
93115
private void givenProcessResults(List<AnalyticsResult> results) {
94116
when(process.readAnalyticsResults()).thenReturn(results.iterator());
95117
}
@@ -99,16 +121,17 @@ private void givenSingleDataFrameBatch(List<DataFrameDataExtractor.Row> batch) t
99121
when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty());
100122
}
101123

102-
private static SearchHit newHit(String id, String json) {
103-
SearchHit hit = new SearchHit(42, id, new Text("doc"), Collections.emptyMap());
124+
private static SearchHit newHit(String json) {
125+
SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap());
104126
hit.sourceRef(new BytesArray(json));
105127
return hit;
106128
}
107129

108-
private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values) {
130+
private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, int checksum) {
109131
DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class);
110132
when(row.getHit()).thenReturn(hit);
111133
when(row.getValues()).thenReturn(values);
134+
when(row.getChecksum()).thenReturn(checksum);
112135
return row;
113136
}
114137

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
1616

1717
@Override
1818
protected AnalyticsResult createTestInstance() {
19-
String idHash = randomAlphaOfLength(20);
19+
int checksum = randomInt();
2020
Map<String, Object> results = new HashMap<>();
2121
int resultsSize = randomIntBetween(1, 10);
2222
for (int i = 0; i < resultsSize; i++) {
2323
String resultField = randomAlphaOfLength(20);
2424
Object resultValue = randomBoolean() ? randomAlphaOfLength(20) : randomDouble();
2525
results.put(resultField, resultValue);
2626
}
27-
return new AnalyticsResult(idHash, results);
27+
return new AnalyticsResult(checksum, results);
2828
}
2929

3030
@Override

0 commit comments

Comments
 (0)