Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

Commit f217d59

Browse files
authored
Prevent DataFrame.Sample() method from returning duplicated rows (#2939)
* resolves #2806 * replace forloop with ArraySegment<T> * reduce shuffle loop operations from O(Rows.Count) to O(numberOfRows)
1 parent 54d3f56 commit f217d59

File tree

4 files changed

+44
-7
lines changed

4 files changed

+44
-7
lines changed

src/Microsoft.Data.Analysis/DataFrame.cs

+18-4
Original file line numberDiff line numberDiff line change
@@ -328,14 +328,28 @@ public DataFrame AddSuffix(string suffix, bool inPlace = false)
328328
/// <param name="numberOfRows">Number of rows in the returned DataFrame</param>
329329
public DataFrame Sample(int numberOfRows)
330330
{
331+
if (numberOfRows > Rows.Count)
332+
{
333+
throw new ArgumentException(string.Format(Strings.ExceedsNumberOfRows, Rows.Count), nameof(numberOfRows));
334+
}
335+
336+
int shuffleLowerLimit = 0;
337+
int shuffleUpperLimit = (int)Math.Min(Int32.MaxValue, Rows.Count);
338+
339+
int[] shuffleArray = Enumerable.Range(0, shuffleUpperLimit).ToArray();
331340
Random rand = new Random();
332-
PrimitiveDataFrameColumn<long> indices = new PrimitiveDataFrameColumn<long>("Indices", numberOfRows);
333-
int randMaxValue = (int)Math.Min(Int32.MaxValue, Rows.Count);
334-
for (long i = 0; i < numberOfRows; i++)
341+
while (shuffleLowerLimit < numberOfRows)
335342
{
336-
indices[i] = rand.Next(randMaxValue);
343+
int randomIndex = rand.Next(shuffleLowerLimit, shuffleUpperLimit);
344+
int temp = shuffleArray[shuffleLowerLimit];
345+
shuffleArray[shuffleLowerLimit] = shuffleArray[randomIndex];
346+
shuffleArray[randomIndex] = temp;
347+
shuffleLowerLimit++;
337348
}
349+
ArraySegment<int> segment = new ArraySegment<int>(shuffleArray, 0, shuffleLowerLimit);
338350

351+
PrimitiveDataFrameColumn<int> indices = new PrimitiveDataFrameColumn<int>("indices", segment);
352+
339353
return Clone(indices);
340354
}
341355

src/Microsoft.Data.Analysis/strings.Designer.cs

+9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/Microsoft.Data.Analysis/strings.resx

+4-1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@
141141
<data name="ExceedsNumberOfColumns" xml:space="preserve">
142142
<value>Parameter.Count exceeds the number of columns({0}) in the DataFrame </value>
143143
</data>
144+
<data name="ExceedsNumberOfRows" xml:space="preserve">
145+
<value>Parameter.Count exceeds the number of rows({0}) in the DataFrame </value>
146+
</data>
144147
<data name="ExpectedEitherGuessRowsOrDataTypes" xml:space="preserve">
145148
<value>Expected either {0} or {1} to be provided</value>
146149
</data>
@@ -186,4 +189,4 @@
186189
<data name="SpansMultipleBuffers" xml:space="preserve">
187190
<value>Cannot span multiple buffers</value>
188191
</data>
189-
</root>
192+
</root>

tests/Microsoft.Data.Analysis.Tests/DataFrameTests.cs

+13-2
Original file line numberDiff line numberDiff line change
@@ -1561,9 +1561,20 @@ public void TestPrefixAndSuffix()
15611561
public void TestSample()
15621562
{
15631563
DataFrame df = MakeDataFrameWithAllColumnTypes(10);
1564-
DataFrame sampled = df.Sample(3);
1565-
Assert.Equal(3, sampled.Rows.Count);
1564+
DataFrame sampled = df.Sample(7);
1565+
Assert.Equal(7, sampled.Rows.Count);
15661566
Assert.Equal(df.Columns.Count, sampled.Columns.Count);
1567+
1568+
// all sampled rows should be unique.
1569+
HashSet<int?> uniqueRowValues = new HashSet<int?>();
1570+
foreach(int? value in sampled.Columns["Int"])
1571+
{
1572+
uniqueRowValues.Add(value);
1573+
}
1574+
Assert.Equal(uniqueRowValues.Count, sampled.Rows.Count);
1575+
1576+
// should throw exception as sample size is greater than dataframe rows
1577+
Assert.Throws<ArgumentException>(()=> df.Sample(13));
15671578
}
15681579

15691580
[Fact]

0 commit comments

Comments
 (0)