Skip to content

Add test coverage for VBuffer #1804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Utilities/MathUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ public static Float GetMedianInPlace(Float[] src, int count)
return (src[iv - 1] + src[iv]) / 2;
}

public static Double CosineSimilarity(Float[] a, Float[] b, int aIdx, int bIdx, int len)
public static Double CosineSimilarity(ReadOnlySpan<Float> a, ReadOnlySpan<Float> b, int aIdx, int bIdx, int len)
{
const Double epsilon = 1e-12f;
Contracts.Assert(len > 0);
Expand Down
28 changes: 9 additions & 19 deletions src/Microsoft.ML.Core/Utilities/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -530,28 +530,10 @@ public static int[] GetRandomPermutation(Random rand, int size)
Contracts.Assert(size >= 0);

var res = GetIdentityPermutation(size);
Shuffle(rand, res);
Shuffle<int>(rand, res);
return res;
}

public static void Shuffle<T>(Random rand, T[] rgv)
{
Contracts.AssertValue(rand);
Contracts.AssertValue(rgv);

Shuffle(rand, rgv, 0, rgv.Length);
}

public static void Shuffle<T>(Random rand, T[] rgv, int min, int lim)
{
Contracts.AssertValue(rand);
Contracts.AssertValue(rgv);
Contracts.Check(0 <= min & min <= lim & lim <= rgv.Length);

for (int iv = min; iv < lim; iv++)
Swap(ref rgv[iv], ref rgv[iv + rand.Next(lim - iv)]);
}

public static bool AreEqual(Single[] arr1, Single[] arr2)
{
if (arr1 == arr2)
Expand Down Expand Up @@ -586,6 +568,14 @@ public static bool AreEqual(Double[] arr1, Double[] arr2)
return true;
}

public static void Shuffle<T>(Random rand, Span<T> rgv)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rgv [](start = 59, length = 3)

just curios, what it stands for?
random general values?

{
Contracts.AssertValue(rand);

for (int iv = 0; iv < rgv.Length; iv++)
Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]);
}

public static bool AreEqual(int[] arr1, int[] arr2)
{
if (arr1 == arr2)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ public static void ApplyInto<TSrc1, TSrc2, TDst>(in VBuffer<TSrc1> a, in VBuffer
// REVIEW: Worth optimizing the newCount == a.Length case?
// Probably not...

editor = VBufferEditor.Create(ref dst, a.Length, newCount);
editor = VBufferEditor.Create(ref dst, a.Length, newCount, requireIndicesOnDense: true);
Span<int> indices = editor.Indices;

if (newCount == bValues.Length)
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.CpuMath/AlignedArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,21 @@ public Float this[int index]
}
}

public void CopyTo(Float[] dst, int index, int count)
public void CopyTo(Span<Float> dst, int index, int count)
{
Contracts.Assert(0 <= count && count <= _size);
Contracts.Assert(dst != null);
Contracts.Assert(0 <= index && index <= dst.Length - count);
Array.Copy(Items, _base, dst, index, count);
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
}

public void CopyTo(int start, Float[] dst, int index, int count)
public void CopyTo(int start, Span<Float> dst, int index, int count)
{
Contracts.Assert(0 <= count);
Contracts.Assert(0 <= start && start <= _size - count);
Contracts.Assert(dst != null);
Contracts.Assert(0 <= index && index <= dst.Length - count);
Array.Copy(Items, start + _base, dst, index, count);
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
}

public void CopyFrom(ReadOnlySpan<Float> src)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ private void GenerateNextBatch()
_batchEnd = newEnd;
}
_totalLeft -= _batchEnd;
Utils.Shuffle(_rand, _batch, 0, _batchEnd);
Utils.Shuffle(_rand, _batch.AsSpan(0, _batchEnd));
}

public int Next()
Expand Down
36 changes: 19 additions & 17 deletions src/Microsoft.ML.Data/Depricated/Vector/GenericSpanSortHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,49 +41,52 @@ internal static int FloorLog2PlusOne(int n)
}
}

internal partial class GenericSpanSortHelper<TKey, TValue>
internal partial class GenericSpanSortHelper<TKey>
where TKey : IComparable<TKey>
{
public static void Sort(Span<TKey> keys, Span<TValue> values, int index, int length)
public static void Sort<TValue>(Span<TKey> keys, Span<TValue> values, int index, int length)
{
Contracts.Assert(keys != null, "Check the arguments in the caller!");
Contracts.Assert(index >= 0 && length >= 0 && (keys.Length - index >= length), "Check the arguments in the caller!");

IntrospectiveSort(keys, values, index, length);
}

private static void SwapIfGreaterWithItems(Span<TKey> keys, Span<TValue> values, int a, int b)
public static void Sort(Span<TKey> keys, int index, int length)
{
Sort(keys, keys, index, length);
}

private static void SwapIfGreaterWithItems<TValue>(Span<TKey> keys, Span<TValue> values, int a, int b)
{
if (a != b)
{
if (keys[a] != null && keys[a].CompareTo(keys[b]) > 0)
{
TKey key = keys[a];
keys[a] = keys[b];
keys[b] = key;

TValue value = values[a];
keys[a] = keys[b];
values[a] = values[b];
keys[b] = key;
values[b] = value;
}
}
}

private static void Swap(Span<TKey> keys, Span<TValue> values, int i, int j)
private static void Swap<TValue>(Span<TKey> keys, Span<TValue> values, int i, int j)
{
if (i != j)
{
TKey k = keys[i];
keys[i] = keys[j];
keys[j] = k;

TValue v = values[i];
keys[i] = keys[j];
values[i] = values[j];
keys[j] = k;
values[j] = v;
}
}

internal static void IntrospectiveSort(Span<TKey> keys, Span<TValue> values, int left, int length)
internal static void IntrospectiveSort<TValue>(Span<TKey> keys, Span<TValue> values, int left, int length)
{
Contracts.Assert(keys != null);
Contracts.Assert(values != null);
Expand All @@ -99,7 +102,7 @@ internal static void IntrospectiveSort(Span<TKey> keys, Span<TValue> values, int
IntroSort(keys, values, left, length + left - 1, 2 * IntrospectiveSortUtilities.FloorLog2PlusOne(length));
}

private static void IntroSort(Span<TKey> keys, Span<TValue> values, int lo, int hi, int depthLimit)
private static void IntroSort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi, int depthLimit)
{
Contracts.Assert(keys != null);
Contracts.Assert(values != null);
Expand Down Expand Up @@ -146,7 +149,7 @@ private static void IntroSort(Span<TKey> keys, Span<TValue> values, int lo, int
}
}

private static int PickPivotAndPartition(Span<TKey> keys, Span<TValue> values, int lo, int hi)
private static int PickPivotAndPartition<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
{
Contracts.Assert(keys != null);
Contracts.Assert(values != null);
Expand Down Expand Up @@ -191,7 +194,7 @@ private static int PickPivotAndPartition(Span<TKey> keys, Span<TValue> values, i
return left;
}

private static void Heapsort(Span<TKey> keys, Span<TValue> values, int lo, int hi)
private static void Heapsort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
{
Contracts.Assert(keys != null);
Contracts.Assert(values != null);
Expand All @@ -211,7 +214,7 @@ private static void Heapsort(Span<TKey> keys, Span<TValue> values, int lo, int h
}
}

private static void DownHeap(Span<TKey> keys, Span<TValue> values, int i, int n, int lo)
private static void DownHeap<TValue>(Span<TKey> keys, Span<TValue> values, int i, int n, int lo)
{
Contracts.Assert(keys != null);
Contracts.Assert(lo >= 0);
Expand All @@ -237,7 +240,7 @@ private static void DownHeap(Span<TKey> keys, Span<TValue> values, int i, int n,
values[lo + i - 1] = dValue;
}

private static void InsertionSort(Span<TKey> keys, Span<TValue> values, int lo, int hi)
private static void InsertionSort<TValue>(Span<TKey> keys, Span<TValue> values, int lo, int hi)
{
Contracts.Assert(keys != null);
Contracts.Assert(values != null);
Expand Down Expand Up @@ -265,5 +268,4 @@ private static void InsertionSort(Span<TKey> keys, Span<TValue> values, int lo,
}
}
}

}
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Depricated/Vector/VBufferMathUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ public static void AddMultWithOffset(in VBuffer<Float> src, Float c, ref VBuffer
editor = VBufferEditor.Create(ref dst,
dst.Length,
dstValues.Length + gapCount,
keepOldOnResize: true);
keepOldOnResize: true,
requireIndicesOnDense: true);
var indices = editor.Indices;
values = editor.Values;
if (gapCount > 0)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Depricated/Vector/VectorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public static void SparsifyNormalize(ref VBuffer<Float> a, int top, int bottom,
}

if (!aEditor.Indices.IsEmpty)
GenericSpanSortHelper<int, float>.Sort(aEditor.Indices, aEditor.Values, 0, newCount);
GenericSpanSortHelper<int>.Sort(aEditor.Indices, aEditor.Values, 0, newCount);
a = aEditor.Commit();
}

Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ private ValueGetter<ReadOnlyMemory<char>> GetTextValueGetter<TSrc>(IRow input, i
contributions.GetValues().CopyTo(values);
var count = values.Length;
var sb = new StringBuilder();
GenericSpanSortHelper<int, float>.Sort(indices, values, 0, count);
GenericSpanSortHelper<int>.Sort(indices, values, 0, count);
for (var i = 0; i < count; i++)
{
var val = values[i];
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ private ValueGetter<VBuffer<float>> MakeGetterOne(IRow input, int iinfo)
return;
}

var editor = VBufferEditor.Create(ref dst, size, 1);
var editor = VBufferEditor.Create(ref dst, size, 1, requireIndicesOnDense: true);
editor.Values[0] = 1;
editor.Indices[0] = (int)src - 1;

Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.FastTree/Utils/LinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ public static int ArgMin<T>(this T[] arr) where T : IComparable<T>
return argMin;
}

public static int ArgMax<T>(this T[] arr) where T : IComparable<T>
public static int ArgMax<T>(this ReadOnlySpan<T> span) where T : IComparable<T>
Copy link
Member

@eerhardt eerhardt Dec 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to remove the method above this, since arrays are implicitly convertible to ReadOnlySpan. #Resolved

{
if (arr.Length == 0)
if (span.Length == 0)
return -1;
int argMax = 0;
for (int i = 1; i < arr.Length; i++)
for (int i = 1; i < span.Length; i++)
{
if (arr[i].CompareTo(arr[argMax]) > 0)
if (span[i].CompareTo(span[argMax]) > 0)
argMax = i;
}
return argMax;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public static ImmutableArray<TResult>

int nextValuesIndex = 0;

Utils.Shuffle(RandomUtils.Create(shuffleSeed), featureValuesBuffer);
Utils.Shuffle<float>(RandomUtils.Create(shuffleSeed), featureValuesBuffer);

Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
(src, dst, state) =>
Expand Down
Loading