Skip to content

Commit 5ff9846

Browse files
committed
[ML] Improve CSV header row detection in find_file_structure
When doing a fieldwise Levenshtein distance comparison between CSV rows, this change ignores all fields that have long values, not just the longest field. This approach works better for CSV formats that have multiple freeform text fields rather than just a single "message" field. Fixes elastic#45047
1 parent 0a6adce commit 5ff9846

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/filestructurefinder/DelimitedFileStructureFinder.java

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.io.StringReader;
1818
import java.util.ArrayList;
1919
import java.util.Arrays;
20+
import java.util.BitSet;
2021
import java.util.Collections;
2122
import java.util.DoubleSummaryStatistics;
2223
import java.util.HashSet;
@@ -27,12 +28,12 @@
2728
import java.util.Random;
2829
import java.util.SortedMap;
2930
import java.util.stream.Collectors;
30-
import java.util.stream.IntStream;
3131

3232
public class DelimitedFileStructureFinder implements FileStructureFinder {
3333

3434
private static final String REGEX_NEEDS_ESCAPE_PATTERN = "([\\\\|()\\[\\]{}^$.+*?])";
3535
private static final int MAX_LEVENSHTEIN_COMPARISONS = 100;
36+
private static final int LONG_FIELD_THRESHOLD = 100;
3637

3738
private final List<String> sampleMessages;
3839
private final FileStructure structure;
@@ -322,10 +323,15 @@ private static boolean isFirstRowUnusual(List<String> explanation, List<List<Str
322323
explanation.add("First row is not unusual based on length test: [" + firstRowLength + "] and [" +
323324
toNiceString(otherRowStats) + "]");
324325

325-
// Check edit distances
326+
// Check edit distances between short fields
326327

328+
BitSet shortFieldMask = makeShortFieldMask(rows, LONG_FIELD_THRESHOLD);
329+
330+
// The reason that only short fields are included is that sometimes
331+
// there are "message" fields that are much longer than the other
332+
// fields, vary enormously between rows, and skew the comparison.
327333
DoubleSummaryStatistics firstRowStats = otherRows.stream().limit(MAX_LEVENSHTEIN_COMPARISONS)
328-
.mapToDouble(otherRow -> (double) levenshteinFieldwiseCompareRows(firstRow, otherRow))
334+
.mapToDouble(otherRow -> (double) levenshteinFieldwiseCompareRows(firstRow, otherRow, shortFieldMask))
329335
.collect(DoubleSummaryStatistics::new, DoubleSummaryStatistics::accept, DoubleSummaryStatistics::combine);
330336

331337
otherRowStats = new DoubleSummaryStatistics();
@@ -336,7 +342,7 @@ private static boolean isFirstRowUnusual(List<String> explanation, List<List<Str
336342
for (int i = 0; numComparisons < MAX_LEVENSHTEIN_COMPARISONS && i < otherRowStrs.size(); ++i) {
337343
for (int j = i + 1 + random.nextInt(innerIncrement); numComparisons < MAX_LEVENSHTEIN_COMPARISONS && j < otherRowStrs.size();
338344
j += innerIncrement) {
339-
otherRowStats.accept((double) levenshteinFieldwiseCompareRows(otherRows.get(i), otherRows.get(j)));
345+
otherRowStats.accept((double) levenshteinFieldwiseCompareRows(otherRows.get(i), otherRows.get(j), shortFieldMask));
340346
++numComparisons;
341347
}
342348
}
@@ -358,30 +364,58 @@ private static String toNiceString(DoubleSummaryStatistics stats) {
358364
stats.getMax());
359365
}
360366

367+
/**
368+
* Make a mask whose bits are set when the corresponding field in every supplied
369+
* row is short, and unset if the corresponding field in any supplied row is long. Missing or null fields count as short.
370+
*/
371+
static BitSet makeShortFieldMask(List<List<String>> rows, int longFieldThreshold) {
372+
373+
assert rows.isEmpty() == false;
374+
375+
BitSet shortFieldMask = new BitSet();
376+
377+
int maxLength = rows.stream().map(List::size).max(Integer::compareTo).get();
378+
for (int index = 0; index < maxLength; ++index) {
379+
final int i = index;
380+
shortFieldMask.set(i,
381+
rows.stream().allMatch(row -> i >= row.size() || row.get(i) == null || row.get(i).length() < longFieldThreshold));
382+
}
383+
384+
return shortFieldMask;
385+
}
386+
361387
/**
362388
* Sum of the Levenshtein distances between corresponding elements
363-
* in the two supplied lists _excluding_ the biggest difference.
364-
* The reason the biggest difference is excluded is that sometimes
365-
* there's a "message" field that is much longer than any of the other
366-
* fields, varies enormously between rows, and skews the comparison.
389+
* in the two supplied lists, with every field position included.
367390
*/
368391
static int levenshteinFieldwiseCompareRows(List<String> firstRow, List<String> secondRow) {
369392

370393
int largestSize = Math.max(firstRow.size(), secondRow.size());
371-
if (largestSize <= 1) {
394+
if (largestSize < 1) {
372395
return 0;
373396
}
374397

375-
int[] distances = new int[largestSize];
398+
BitSet allFields = new BitSet();
399+
allFields.set(0, largestSize);
400+
401+
return levenshteinFieldwiseCompareRows(firstRow, secondRow, allFields);
402+
}
376403

377-
for (int index = 0; index < largestSize; ++index) {
378-
distances[index] = levenshteinDistance((index < firstRow.size()) ? firstRow.get(index) : "",
404+
/**
405+
* Sum of the Levenshtein distances between corresponding elements
406+
* in the two supplied lists where the corresponding bit in the
407+
* supplied bit mask is set. A field position beyond the end of either list is compared as an empty string.
408+
*/
409+
static int levenshteinFieldwiseCompareRows(List<String> firstRow, List<String> secondRow, BitSet fieldMask) {
410+
411+
int result = 0;
412+
413+
for (int index = fieldMask.nextSetBit(0); index >= 0; index = fieldMask.nextSetBit(index + 1)) {
414+
result += levenshteinDistance((index < firstRow.size()) ? firstRow.get(index) : "",
379415
(index < secondRow.size()) ? secondRow.get(index) : "");
380416
}
381417

382-
Arrays.sort(distances);
383-
384-
return IntStream.of(distances).limit(distances.length - 1).sum();
418+
return result;
385419
}
386420

387421
/**

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/filestructurefinder/DelimitedFileStructureFinderTests.java

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import org.supercsv.prefs.CsvPreference;
1111

1212
import java.io.IOException;
13+
import java.util.ArrayList;
1314
import java.util.Arrays;
15+
import java.util.BitSet;
1416
import java.util.Collections;
17+
import java.util.List;
1518

1619
import static org.elasticsearch.xpack.ml.filestructurefinder.DelimitedFileStructureFinder.levenshteinFieldwiseCompareRows;
1720
import static org.elasticsearch.xpack.ml.filestructurefinder.DelimitedFileStructureFinder.levenshteinDistance;
1821
import static org.hamcrest.Matchers.arrayContaining;
22+
import static org.hamcrest.Matchers.equalTo;
1923

2024
public class DelimitedFileStructureFinderTests extends FileStructureTestCase {
2125

@@ -449,15 +453,51 @@ public void testLevenshteinDistance() {
449453
assertEquals(0, levenshteinDistance("", ""));
450454
}
451455

456+
public void testMakeShortFieldMask() {
457+
458+
List<List<String>> rows = new ArrayList<>();
459+
rows.add(Arrays.asList(randomAlphaOfLength(5), randomAlphaOfLength(20), randomAlphaOfLength(5)));
460+
rows.add(Arrays.asList(randomAlphaOfLength(50), randomAlphaOfLength(5), randomAlphaOfLength(5)));
461+
rows.add(Arrays.asList(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5)));
462+
rows.add(Arrays.asList(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(80)));
463+
464+
BitSet shortFieldMask = DelimitedFileStructureFinder.makeShortFieldMask(rows, 110);
465+
assertThat(shortFieldMask, equalTo(TimestampFormatFinder.stringToNumberPosBitSet("111")));
466+
shortFieldMask = DelimitedFileStructureFinder.makeShortFieldMask(rows, 80);
467+
assertThat(shortFieldMask, equalTo(TimestampFormatFinder.stringToNumberPosBitSet("11 ")));
468+
shortFieldMask = DelimitedFileStructureFinder.makeShortFieldMask(rows, 50);
469+
assertThat(shortFieldMask, equalTo(TimestampFormatFinder.stringToNumberPosBitSet(" 1 ")));
470+
shortFieldMask = DelimitedFileStructureFinder.makeShortFieldMask(rows, 20);
471+
assertThat(shortFieldMask, equalTo(TimestampFormatFinder.stringToNumberPosBitSet(" ")));
472+
}
473+
452474
public void testLevenshteinCompareRows() {
453475

454476
assertEquals(0, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("cat", "dog")));
455-
assertEquals(0, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("cat", "cat")));
456-
assertEquals(3, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("dog", "cat")));
457-
assertEquals(3, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("mouse", "cat")));
458-
assertEquals(5, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "dog", "cat")));
459-
assertEquals(4, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "mouse", "mouse")));
460-
assertEquals(7, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "cat", "dog")));
477+
assertEquals(3, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("cat", "cat")));
478+
assertEquals(6, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("dog", "cat")));
479+
assertEquals(8, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("mouse", "cat")));
480+
assertEquals(10, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "dog", "cat")));
481+
assertEquals(9, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "mouse", "mouse")));
482+
assertEquals(12, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "cat", "dog")));
483+
}
484+
485+
public void testLevenshteinCompareRowsWithMask() {
486+
487+
assertEquals(0, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("cat", "dog"),
488+
TimestampFormatFinder.stringToNumberPosBitSet(randomFrom(" ", "1 ", " 1", "11"))));
489+
assertEquals(0, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("cat", "cat"),
490+
TimestampFormatFinder.stringToNumberPosBitSet(randomFrom(" ", "1 "))));
491+
assertEquals(3, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("dog", "cat"),
492+
TimestampFormatFinder.stringToNumberPosBitSet(randomFrom(" 1", "1 "))));
493+
assertEquals(3, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog"), Arrays.asList("mouse", "cat"),
494+
TimestampFormatFinder.stringToNumberPosBitSet(" 1")));
495+
assertEquals(5, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "dog", "cat"),
496+
TimestampFormatFinder.stringToNumberPosBitSet(" 11")));
497+
assertEquals(4, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "mouse", "mouse"),
498+
TimestampFormatFinder.stringToNumberPosBitSet(" 11")));
499+
assertEquals(7, levenshteinFieldwiseCompareRows(Arrays.asList("cat", "dog", "mouse"), Arrays.asList("mouse", "cat", "dog"),
500+
TimestampFormatFinder.stringToNumberPosBitSet(" 11")));
461501
}
462502

463503
public void testLineHasUnescapedQuote() {

0 commit comments

Comments
 (0)