Skip to content

Commit f0fb2a0

Browse files
authored
Add test coverage for VBuffer (#1804)
* Add ScaleBy test
* Add VBuffer unit tests, and fix a couple of bugs
* Fix test tolerance
* Address PR comments and fix unit test tolerance
* Fix AlignedArray bug
1 parent 93156b6 commit f0fb2a0

File tree

15 files changed

+1273
-117
lines changed

15 files changed

+1273
-117
lines changed

src/Microsoft.ML.Core/Utilities/MathUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ public static Float GetMedianInPlace(Float[] src, int count)
726726
return (src[iv - 1] + src[iv]) / 2;
727727
}
728728

729-
public static Double CosineSimilarity(Float[] a, Float[] b, int aIdx, int bIdx, int len)
729+
public static Double CosineSimilarity(ReadOnlySpan<Float> a, ReadOnlySpan<Float> b, int aIdx, int bIdx, int len)
730730
{
731731
const Double epsilon = 1e-12f;
732732
Contracts.Assert(len > 0);

src/Microsoft.ML.Core/Utilities/Utils.cs

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -530,28 +530,10 @@ public static int[] GetRandomPermutation(Random rand, int size)
530530
Contracts.Assert(size >= 0);
531531

532532
var res = GetIdentityPermutation(size);
533-
Shuffle(rand, res);
533+
Shuffle<int>(rand, res);
534534
return res;
535535
}
536536

537-
public static void Shuffle<T>(Random rand, T[] rgv)
538-
{
539-
Contracts.AssertValue(rand);
540-
Contracts.AssertValue(rgv);
541-
542-
Shuffle(rand, rgv, 0, rgv.Length);
543-
}
544-
545-
public static void Shuffle<T>(Random rand, T[] rgv, int min, int lim)
546-
{
547-
Contracts.AssertValue(rand);
548-
Contracts.AssertValue(rgv);
549-
Contracts.Check(0 <= min & min <= lim & lim <= rgv.Length);
550-
551-
for (int iv = min; iv < lim; iv++)
552-
Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]);
553-
}
554-
555537
public static bool AreEqual(Single[] arr1, Single[] arr2)
556538
{
557539
if (arr1 == arr2)
@@ -586,6 +568,14 @@ public static bool AreEqual(Double[] arr1, Double[] arr2)
586568
return true;
587569
}
588570

571+
/// <summary>
/// Randomly reorders the elements of <paramref name="rgv"/> in place.
/// Walking forward through the span, each position is swapped with a
/// position drawn by <paramref name="rand"/> from the not-yet-visited
/// tail (the current position included), so exactly one random number
/// is consumed per element.
/// </summary>
/// <typeparam name="T">Element type of the span.</typeparam>
/// <param name="rand">Source of randomness; validated with <c>Contracts.AssertValue</c>.</param>
/// <param name="rgv">The span whose elements are shuffled. To shuffle part of an array, pass a slice.</param>
public static void Shuffle<T>(Random rand, Span<T> rgv)
{
    Contracts.AssertValue(rand);

    int count = rgv.Length;
    for (int current = 0; current < count; current++)
    {
        // Offset in [0, count - current): the pick may be the current slot itself.
        int picked = current + rand.Next(count - current);
        Swap(ref rgv[current], ref rgv[picked]);
    }
}
578+
589579
public static bool AreEqual(int[] arr1, int[] arr2)
590580
{
591581
if (arr1 == arr2)

src/Microsoft.ML.Core/Utilities/VBufferUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ public static void ApplyInto<TSrc1, TSrc2, TDst>(in VBuffer<TSrc1> a, in VBuffer
13221322
// REVIEW: Worth optimizing the newCount == a.Length case?
13231323
// Probably not...
13241324

1325-
editor = VBufferEditor.Create(ref dst, a.Length, newCount);
1325+
editor = VBufferEditor.Create(ref dst, a.Length, newCount, requireIndicesOnDense: true);
13261326
Span<int> indices = editor.Indices;
13271327

13281328
if (newCount == bValues.Length)

src/Microsoft.ML.CpuMath/AlignedArray.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,21 @@ public Float this[int index]
110110
}
111111
}
112112

113-
public void CopyTo(Float[] dst, int index, int count)
113+
public void CopyTo(Span<Float> dst, int index, int count)
114114
{
115115
Contracts.Assert(0 <= count && count <= _size);
116116
Contracts.Assert(dst != null);
117117
Contracts.Assert(0 <= index && index <= dst.Length - count);
118-
Array.Copy(Items, _base, dst, index, count);
118+
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
119119
}
120120

121-
public void CopyTo(int start, Float[] dst, int index, int count)
121+
public void CopyTo(int start, Span<Float> dst, int index, int count)
122122
{
123123
Contracts.Assert(0 <= count);
124124
Contracts.Assert(0 <= start && start <= _size - count);
125125
Contracts.Assert(dst != null);
126126
Contracts.Assert(0 <= index && index <= dst.Length - count);
127-
Array.Copy(Items, start + _base, dst, index, count);
127+
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
128128
}
129129

130130
public void CopyFrom(ReadOnlySpan<Float> src)

src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ private void GenerateNextBatch()
444444
_batchEnd = newEnd;
445445
}
446446
_totalLeft -= _batchEnd;
447-
Utils.Shuffle(_rand, _batch, 0, _batchEnd);
447+
Utils.Shuffle(_rand, _batch.AsSpan(0, _batchEnd));
448448
}
449449

450450
public int Next()

src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,49 +41,52 @@ internal static int FloorLog2PlusOne(int n)
4141
}
4242
}
4343

44-
internal partial class GenericSpanSortHelper<TKey, TValue>
44+
internal partial class GenericSpanSortHelper<TKey>
4545
where TKey : IComparable<TKey>
4646
{
47-
public static void Sort(Span<TKey> keys, Span<TValue> values, int index, int length)
47+
public static void Sort<TValue>(Span<TKey> keys, Span<TValue> values, int index, int length)
4848
{
4949
Contracts.Assert(keys != null, "Check the arguments in the caller!");
5050
Contracts.Assert(index >= 0 && length >= 0 && (keys.Length - index >= length), "Check the arguments in the caller!");
5151

5252
IntrospectiveSort(keys, values, index, length);
5353
}
5454

55-
private static void SwapIfGreaterWithItems(Span<TKey> keys, Span<TValue> values, int a, int b)
55+
/// <summary>
/// Sorts the range of <paramref name="keys"/> that starts at
/// <paramref name="index"/> and spans <paramref name="length"/> elements,
/// without a separate values span.
/// </summary>
/// <remarks>
/// Forwards to the keys-and-values overload, passing <paramref name="keys"/>
/// as both the keys span and the values span.
/// </remarks>
/// <param name="keys">The span to sort.</param>
/// <param name="index">Start of the range to sort.</param>
/// <param name="length">Number of elements in the range to sort.</param>
public static void Sort(Span<TKey> keys, int index, int length)
    => Sort<TKey>(keys, keys, index, length);
59+
60+
private static void SwapIfGreaterWithItems<TValue>(Span<TKey> keys, Span<TValue> values, int a, int b)
5661
{
5762
if (a != b)
5863
{
5964
if (keys[a] != null && keys[a].CompareTo(keys[b]) > 0)
6065
{
6166
TKey key = keys[a];
62-
keys[a] = keys[b];
63-
keys[b] = key;
64-
6567
TValue value = values[a];
68+
keys[a] = keys[b];
6669
values[a] = values[b];
70+
keys[b] = key;
6771
values[b] = value;
6872
}
6973
}
7074
}
7175

72-
private static void Swap(Span<TKey> keys, Span<TValue> values, int i, int j)
76+
private static void Swap<TValue>(Span<TKey> keys, Span<TValue> values, int i, int j)
7377
{
7478
if (i != j)
7579
{
7680
TKey k = keys[i];
77-
keys[i] = keys[j];
78-
keys[j] = k;
79-
8081
TValue v = values[i];
82+
keys[i] = keys[j];
8183
values[i] = values[j];
84+
keys[j] = k;
8285
values[j] = v;
8386
}
8487
}
8588

86-
internal static void IntrospectiveSort(Span<TKey> keys, Span<TValue> values, int left, int length)
89+
internal static void IntrospectiveSort<TValue>(Span<TKey> keys, Span<TValue> values, int left, int length)
8790
{
8891
Contracts.Assert(keys != null);
8992
Contracts.Assert(values != null);
@@ -99,7 +102,7 @@ internal static void IntrospectiveSort(Span<TKey> keys, Span<TValue> values, int
99102
IntroSort(keys, values, left, length + left - 1, 2 * IntrospectiveSortUtilities.FloorLog2PlusOne(length));
100103
}
101104

102-
private static void IntroSort(Span<TKey> keys, Span<TValue> values, int lo, int hi, int depthLimit)
105+
private static void IntroSort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi, int depthLimit)
103106
{
104107
Contracts.Assert(keys != null);
105108
Contracts.Assert(values != null);
@@ -146,7 +149,7 @@ private static void IntroSort(Span<TKey> keys, Span<TValue> values, int lo, int
146149
}
147150
}
148151

149-
private static int PickPivotAndPartition(Span<TKey> keys, Span<TValue> values, int lo, int hi)
152+
private static int PickPivotAndPartition<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
150153
{
151154
Contracts.Assert(keys != null);
152155
Contracts.Assert(values != null);
@@ -191,7 +194,7 @@ private static int PickPivotAndPartition(Span<TKey> keys, Span<TValue> values, i
191194
return left;
192195
}
193196

194-
private static void Heapsort(Span<TKey> keys, Span<TValue> values, int lo, int hi)
197+
private static void Heapsort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
195198
{
196199
Contracts.Assert(keys != null);
197200
Contracts.Assert(values != null);
@@ -211,7 +214,7 @@ private static void Heapsort(Span<TKey> keys, Span<TValue> values, int lo, int h
211214
}
212215
}
213216

214-
private static void DownHeap(Span<TKey> keys, Span<TValue> values, int i, int n, int lo)
217+
private static void DownHeap<TValue>(Span<TKey> keys, Span<TValue> values, int i, int n, int lo)
215218
{
216219
Contracts.Assert(keys != null);
217220
Contracts.Assert(lo >= 0);
@@ -237,7 +240,7 @@ private static void DownHeap(Span<TKey> keys, Span<TValue> values, int i, int n,
237240
values[lo + i - 1] = dValue;
238241
}
239242

240-
private static void InsertionSort(Span<TKey> keys, Span<TValue> values, int lo, int hi)
243+
private static void InsertionSort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
241244
{
242245
Contracts.Assert(keys != null);
243246
Contracts.Assert(values != null);
@@ -265,5 +268,4 @@ private static void InsertionSort(Span<TKey> keys, Span<TValue> values, int lo,
265268
}
266269
}
267270
}
268-
269271
}

src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ public static void AddMultWithOffset(in VBuffer<Float> src, Float c, ref VBuffer
298298
editor = VBufferEditor.Create(ref dst,
299299
dst.Length,
300300
dstValues.Length + gapCount,
301-
keepOldOnResize: true);
301+
keepOldOnResize: true,
302+
requireIndicesOnDense: true);
302303
var indices = editor.Indices;
303304
values = editor.Values;
304305
if (gapCount > 0)

src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public static void SparsifyNormalize(ref VBuffer<Float> a, int top, int bottom,
150150
}
151151

152152
if (!aEditor.Indices.IsEmpty)
153-
GenericSpanSortHelper<int, float>.Sort(aEditor.Indices, aEditor.Values, 0, newCount);
153+
GenericSpanSortHelper<int>.Sort(aEditor.Indices, aEditor.Values, 0, newCount);
154154
a = aEditor.Commit();
155155
}
156156

src/Microsoft.ML.Data/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
99
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
10+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
1011
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
1112
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)]
1213

src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ private ValueGetter<ReadOnlyMemory<char>> GetTextValueGetter<TSrc>(IRow input, i
273273
contributions.GetValues().CopyTo(values);
274274
var count = values.Length;
275275
var sb = new StringBuilder();
276-
GenericSpanSortHelper<int, float>.Sort(indices, values, 0, count);
276+
GenericSpanSortHelper<int>.Sort(indices, values, 0, count);
277277
for (var i = 0; i < count; i++)
278278
{
279279
var val = values[i];

src/Microsoft.ML.Data/Transforms/KeyToVector.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ private ValueGetter<VBuffer<float>> MakeGetterOne(IRow input, int iinfo)
478478
return;
479479
}
480480

481-
var editor = VBufferEditor.Create(ref dst, size, 1);
481+
var editor = VBufferEditor.Create(ref dst, size, 1, requireIndicesOnDense: true);
482482
editor.Values[0] = 1;
483483
editor.Indices[0] = (int)src - 1;
484484

src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ public static int ArgMin<T>(this T[] arr) where T : IComparable<T>
2323
return argMin;
2424
}
2525

26-
public static int ArgMax<T>(this T[] arr) where T : IComparable<T>
26+
public static int ArgMax<T>(this ReadOnlySpan<T> span) where T : IComparable<T>
2727
{
28-
if (arr.Length == 0)
28+
if (span.Length == 0)
2929
return -1;
3030
int argMax = 0;
31-
for (int i = 1; i < arr.Length; i++)
31+
for (int i = 1; i < span.Length; i++)
3232
{
33-
if (arr[i].CompareTo(arr[argMax]) > 0)
33+
if (span[i].CompareTo(span[argMax]) > 0)
3434
argMax = i;
3535
}
3636
return argMax;

src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public static ImmutableArray<TResult>
180180

181181
int nextValuesIndex = 0;
182182

183-
Utils.Shuffle(RandomUtils.Create(shuffleSeed), featureValuesBuffer);
183+
Utils.Shuffle<float>(RandomUtils.Create(shuffleSeed), featureValuesBuffer);
184184

185185
Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
186186
(src, dst, state) =>

0 commit comments

Comments
 (0)