Skip to content

[FEATURE][ML] Skip rows that have missing data #36067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
Expand Down Expand Up @@ -68,18 +69,18 @@ public void cancel() {
isCancelled = true;
}

public Optional<List<String[]>> next() throws IOException {
public Optional<List<Row>> next() throws IOException {
if (!hasNext()) {
throw new NoSuchElementException();
}
Optional<List<String[]>> records = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll());
if (!records.isPresent()) {
Optional<List<Row>> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll());
if (!hits.isPresent()) {
hasNext = false;
}
return records;
return hits;
}

protected List<String[]> initScroll() throws IOException {
protected List<Row> initScroll() throws IOException {
LOGGER.debug("[{}] Initializing scroll", "analytics");
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest());
LOGGER.debug("[{}] Search response was obtained", context.jobId);
Expand All @@ -106,7 +107,7 @@ private SearchRequestBuilder buildSearchRequest() {
return searchRequestBuilder;
}

private List<String[]> processSearchResponse(SearchResponse searchResponse) throws IOException {
private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {

if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
Expand All @@ -123,29 +124,35 @@ private List<String[]> processSearchResponse(SearchResponse searchResponse) thro
}

SearchHit[] hits = searchResponse.getHits().getHits();
List<String[]> records = new ArrayList<>(hits.length);
List<Row> rows = new ArrayList<>(hits.length);
for (SearchHit hit : hits) {
if (isCancelled) {
hasNext = false;
clearScroll(scrollId);
break;
}
records.add(toStringArray(hit));
rows.add(createRow(hit));
}
return records;
return rows;

}

private String[] toStringArray(SearchHit hit) {
String[] result = new String[context.extractedFields.getAllFields().size()];
for (int i = 0; i < result.length; ++i) {
private Row createRow(SearchHit hit) {
String[] extractedValues = new String[context.extractedFields.getAllFields().size()];
for (int i = 0; i < extractedValues.length; ++i) {
ExtractedField field = context.extractedFields.getAllFields().get(i);
Object[] values = field.value(hit);
result[i] = (values.length == 1) ? Objects.toString(values[0]) : "";
if (values.length == 1 && values[0] instanceof Number) {
extractedValues[i] = Objects.toString(values[0]);
} else {
extractedValues = null;
break;
}
}
return result;
return new Row(extractedValues);
}

private List<String[]> continueScroll() throws IOException {
private List<Row> continueScroll() throws IOException {
LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
SearchResponse searchResponse = executeSearchScrollRequest(scrollId);
LOGGER.debug("[{}] Search response was obtained", context.jobId);
Expand Down Expand Up @@ -210,4 +217,19 @@ public DataSummary(long rows, long cols) {
this.cols = cols;
}
}

public static class Row {

@Nullable
private String[] values;

private Row(String[] values) {
this.values = values;
}

@Nullable
public String[] getValues() {
return values;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) {
process.writeRecord(dataExtractor.getFieldNamesArray());

while (dataExtractor.hasNext()) {
Optional<List<String[]>> records = dataExtractor.next();
if (records.isPresent()) {
for (String[] record : records.get()) {
process.writeRecord(record);
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
if (rows.isPresent()) {
for (DataFrameDataExtractor.Row row : rows.get()) {
String[] rowValues = row.getValues();
if (rowValues != null) {
process.writeRecord(rowValues);
}
}
}
}
Expand Down