Skip to content

Commit 8cf5337

Browse files
[FEATURE][ML] Skip rows that have missing data (#36067)
This also prepares for allowing the `DataFrameDataExtractor` to be reused while joining the results with the raw documents.
1 parent 68bc59c commit 8cf5337

File tree

2 files changed

+44
-19
lines changed

2 files changed

+44
-19
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.action.search.SearchScrollAction;
1616
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
1717
import org.elasticsearch.client.Client;
18+
import org.elasticsearch.common.Nullable;
1819
import org.elasticsearch.common.unit.TimeValue;
1920
import org.elasticsearch.search.SearchHit;
2021
import org.elasticsearch.search.sort.SortOrder;
@@ -68,18 +69,18 @@ public void cancel() {
6869
isCancelled = true;
6970
}
7071

71-
public Optional<List<String[]>> next() throws IOException {
72+
public Optional<List<Row>> next() throws IOException {
7273
if (!hasNext()) {
7374
throw new NoSuchElementException();
7475
}
75-
Optional<List<String[]>> records = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll());
76-
if (!records.isPresent()) {
76+
Optional<List<Row>> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll());
77+
if (!hits.isPresent()) {
7778
hasNext = false;
7879
}
79-
return records;
80+
return hits;
8081
}
8182

82-
protected List<String[]> initScroll() throws IOException {
83+
protected List<Row> initScroll() throws IOException {
8384
LOGGER.debug("[{}] Initializing scroll", "analytics");
8485
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest());
8586
LOGGER.debug("[{}] Search response was obtained", context.jobId);
@@ -106,7 +107,7 @@ private SearchRequestBuilder buildSearchRequest() {
106107
return searchRequestBuilder;
107108
}
108109

109-
private List<String[]> processSearchResponse(SearchResponse searchResponse) throws IOException {
110+
private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {
110111

111112
if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
112113
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
@@ -123,29 +124,35 @@ private List<String[]> processSearchResponse(SearchResponse searchResponse) thro
123124
}
124125

125126
SearchHit[] hits = searchResponse.getHits().getHits();
126-
List<String[]> records = new ArrayList<>(hits.length);
127+
List<Row> rows = new ArrayList<>(hits.length);
127128
for (SearchHit hit : hits) {
128129
if (isCancelled) {
129130
hasNext = false;
130131
clearScroll(scrollId);
131132
break;
132133
}
133-
records.add(toStringArray(hit));
134+
rows.add(createRow(hit));
134135
}
135-
return records;
136+
return rows;
137+
136138
}
137139

138-
private String[] toStringArray(SearchHit hit) {
139-
String[] result = new String[context.extractedFields.getAllFields().size()];
140-
for (int i = 0; i < result.length; ++i) {
140+
private Row createRow(SearchHit hit) {
141+
String[] extractedValues = new String[context.extractedFields.getAllFields().size()];
142+
for (int i = 0; i < extractedValues.length; ++i) {
141143
ExtractedField field = context.extractedFields.getAllFields().get(i);
142144
Object[] values = field.value(hit);
143-
result[i] = (values.length == 1) ? Objects.toString(values[0]) : "";
145+
if (values.length == 1 && values[0] instanceof Number) {
146+
extractedValues[i] = Objects.toString(values[0]);
147+
} else {
148+
extractedValues = null;
149+
break;
150+
}
144151
}
145-
return result;
152+
return new Row(extractedValues);
146153
}
147154

148-
private List<String[]> continueScroll() throws IOException {
155+
private List<Row> continueScroll() throws IOException {
149156
LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
150157
SearchResponse searchResponse = executeSearchScrollRequest(scrollId);
151158
LOGGER.debug("[{}] Search response was obtained", context.jobId);
@@ -210,4 +217,19 @@ public DataSummary(long rows, long cols) {
210217
this.cols = cols;
211218
}
212219
}
220+
221+
public static class Row {
222+
223+
@Nullable
224+
private String[] values;
225+
226+
private Row(String[] values) {
227+
this.values = values;
228+
}
229+
230+
@Nullable
231+
public String[] getValues() {
232+
return values;
233+
}
234+
}
213235
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) {
4848
process.writeRecord(dataExtractor.getFieldNamesArray());
4949

5050
while (dataExtractor.hasNext()) {
51-
Optional<List<String[]>> records = dataExtractor.next();
52-
if (records.isPresent()) {
53-
for (String[] record : records.get()) {
54-
process.writeRecord(record);
51+
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
52+
if (rows.isPresent()) {
53+
for (DataFrameDataExtractor.Row row : rows.get()) {
54+
String[] rowValues = row.getValues();
55+
if (rowValues != null) {
56+
process.writeRecord(rowValues);
57+
}
5558
}
5659
}
5760
}

0 commit comments

Comments
 (0)