Skip to content

Commit 3e03fce

Browse files
authored
Test GAM Public APIs (#2188)
* Refactoring the GAM Trainer / Predictor to move all Training information into the trainer, and make the GAM predictor generic. * Adding a test for IsSorted<double>, changing all IsSorted to use IList and renaming to IsMonotonicallyIncreasing.
1 parent a65da53 commit 3e03fce

File tree

9 files changed

+269
-66
lines changed

9 files changed

+269
-66
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs

+8-5
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ public static void RunExample()
7777
// First, let's get the index of the variable we want to look at
7878
var studentTeacherRatioIndex = featureNames.ToList().FindIndex(str => str.Equals("TeacherRatio"));
7979

80-
// Next, let's get the array of bin upper bounds from the model for this feature
81-
var teacherRatioBinUpperBounds = gamModel.GetFeatureBinUpperBounds(studentTeacherRatioIndex);
82-
// And the array of bin weights; these are the effect size for each bin
83-
var teacherRatioFeatureWeights = gamModel.GetFeatureWeights(studentTeacherRatioIndex);
80+
// Next, let's get the array of histogram bin upper bounds from the model for this feature
81+
// For each feature, the shape function is calculated at `MaxBins` locations along the range of
82+
// values that the feature takes, and the resulting shape function can be seen as a histogram of
83+
// effects.
84+
var teacherRatioBinUpperBounds = gamModel.GetBinUpperBounds(studentTeacherRatioIndex);
85+
// And the array of bin effects; these are the effect size for each bin
86+
var teacherRatioBinEffects = gamModel.GetBinEffects(studentTeacherRatioIndex);
8487

8588
// Now, write the function to the console. The function is a set of bins, and the corresponding
8689
// function values. You can think of GAMs as building a bar-chart lookup table.
@@ -118,7 +121,7 @@ public static void RunExample()
118121
Console.WriteLine("Student-Teacher Ratio");
119122
for (int i = 0; i < teacherRatioBinUpperBounds.Length; i++)
120123
{
121-
Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioFeatureWeights[i]:0.000}");
124+
Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioBinEffects[i]:0.000}");
122125
}
123126
Console.WriteLine();
124127
}

src/Microsoft.ML.Core/Utilities/Utils.cs

+62-19
Original file line numberDiff line numberDiff line change
@@ -534,14 +534,6 @@ public static int[] GetRandomPermutation(Random rand, int size)
534534
return res;
535535
}
536536

537-
public static void Shuffle<T>(Random rand, Span<T> rgv)
538-
{
539-
Contracts.AssertValue(rand);
540-
541-
for (int iv = 0; iv < rgv.Length; iv++)
542-
Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]);
543-
}
544-
545537
public static bool AreEqual(float[] arr1, float[] arr2)
546538
{
547539
if (arr1 == arr2)
@@ -576,6 +568,14 @@ public static bool AreEqual(double[] arr1, double[] arr2)
576568
return true;
577569
}
578570

571+
public static void Shuffle<T>(Random rand, Span<T> rgv)
572+
{
573+
Contracts.AssertValue(rand);
574+
575+
for (int iv = 0; iv < rgv.Length; iv++)
576+
Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]);
577+
}
578+
579579
public static bool AreEqual(int[] arr1, int[] arr2)
580580
{
581581
if (arr1 == arr2)
@@ -615,37 +615,80 @@ public static string ExtractLettersAndNumbers(string value)
615615
return Regex.Replace(value, "[^A-Za-z0-9]", "");
616616
}
617617

618-
public static bool IsSorted(IList<float> values)
618+
/// <summary>
619+
/// Checks that an input IList is monotonically increasing.
620+
/// </summary>
621+
/// <param name="values">An array of values</param>
622+
/// <returns>True if the array is monotonically increasing (if each element is greater
623+
/// than or equal to previous elements); false otherwise. ILists containing NaN values
624+
/// are considered to be not monotonically increasing.</returns>
625+
public static bool IsMonotonicallyIncreasing(IList<float> values)
619626
{
620627
if (Utils.Size(values) <= 1)
621628
return true;
622629

623-
var prev = values[0];
624-
625-
for (int i = 1; i < values.Count; i++)
630+
var previousValue = values[0];
631+
var listLength = values.Count;
632+
for (int i = 1; i < listLength; i++)
626633
{
627-
if (!(values[i] >= prev))
634+
var currentValue = values[i];
635+
// Inverted check for NaNs
636+
if (!(currentValue >= previousValue))
628637
return false;
629638

630-
prev = values[i];
639+
previousValue = currentValue;
631640
}
632641

633642
return true;
634643
}
635644

636-
public static bool IsSorted(int[] values)
645+
/// <summary>
646+
/// Checks that an input array is monotonically increasing.
647+
/// </summary>
648+
/// <param name="values">An array of values</param>
649+
/// <returns>True if the array is monotonically increasing (if each element is greater
650+
/// than or equal to previous elements); false otherwise.</returns>
651+
public static bool IsMonotonicallyIncreasing(IList<int> values)
637652
{
638653
if (Utils.Size(values) <= 1)
639654
return true;
640655

641-
var prev = values[0];
656+
var previousValue = values[0];
657+
var listLength = values.Count;
658+
for (int i = 1; i < listLength; i++)
659+
{
660+
var currentValue = values[i];
661+
if (currentValue < previousValue)
662+
return false;
642663

643-
for (int i = 1; i < values.Length; i++)
664+
previousValue = currentValue;
665+
}
666+
667+
return true;
668+
}
669+
670+
/// <summary>
671+
/// Checks that an input array is monotonically increasing.
672+
/// </summary>
673+
/// <param name="values">An array of values</param>
674+
/// <returns>True if the array is monotonically increasing (if each element is greater
675+
/// than or equal to previous elements); false otherwise. Arrays containing NaN values
676+
/// are considered to be not monotonically increasing.</returns>
677+
public static bool IsMonotonicallyIncreasing(IList<double> values)
678+
{
679+
if (Utils.Size(values) <= 1)
680+
return true;
681+
682+
var previousValue = values[0];
683+
var listLength = values.Count;
684+
for (int i = 1; i < listLength; i++)
644685
{
645-
if (values[i] < prev)
686+
var currentValue = values[i];
687+
// Inverted check for NaNs
688+
if (!(currentValue >= previousValue))
646689
return false;
647690

648-
prev = values[i];
691+
previousValue = currentValue;
649692
}
650693

651694
return true;

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1653,9 +1653,9 @@ public PavCalibrator(IHostEnvironment env, ImmutableArray<float> mins, Immutable
16531653
_host.AssertNonEmpty(mins);
16541654
_host.AssertNonEmpty(maxes);
16551655
_host.AssertNonEmpty(values);
1656-
_host.Assert(Utils.IsSorted(mins));
1657-
_host.Assert(Utils.IsSorted(maxes));
1658-
_host.Assert(Utils.IsSorted(values));
1656+
_host.Assert(Utils.IsMonotonicallyIncreasing(mins));
1657+
_host.Assert(Utils.IsMonotonicallyIncreasing(maxes));
1658+
_host.Assert(Utils.IsMonotonicallyIncreasing(values));
16591659
_host.Assert(values.Length == 0 || (0 <= values[0] && values[values.Length - 1] <= 1));
16601660
_host.Assert(mins.Zip(maxes, (min, max) => min <= max).All(x => x));
16611661

src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ protected SinglePartitionedIntArrayFlockBase(TIntArray bins, int[] hotFeatureSta
11701170
Contracts.AssertValue(binUpperBounds);
11711171
Contracts.Assert(Utils.Size(hotFeatureStarts) == binUpperBounds.Length + 1); // One more than number of features.
11721172
Contracts.Assert(hotFeatureStarts[0] == 1);
1173-
Contracts.Assert(Utils.IsSorted(hotFeatureStarts));
1173+
Contracts.Assert(Utils.IsMonotonicallyIncreasing(hotFeatureStarts));
11741174
Contracts.Assert(bins.Max() < hotFeatureStarts[hotFeatureStarts.Length - 1]);
11751175

11761176
Bins = bins;

src/Microsoft.ML.FastTree/GamModelParameters.cs

+55-24
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ private protected GamModelParametersBase(IHostEnvironment env, string name,
8989
// Check data validity
9090
Host.CheckValue(binEffects[i], nameof(binEffects), "Array contained null entries");
9191
Host.CheckParam(binUpperBounds[i].Length == binEffects[i].Length, nameof(binEffects), "Array contained wrong number of effect values");
92+
Host.CheckParam(Utils.IsMonotonicallyIncreasing(binUpperBounds[i]), nameof(binUpperBounds), "Array must be monotonically increasing");
9293

9394
// Update the value at zero
9495
_valueAtAllZero += GetBinEffect(i, 0, out _binsAtAllZero[i]);
@@ -282,44 +283,74 @@ private double GetBinEffect(int featureIndex, double featureValue, out int binIn
282283
/// Get the bin upper bounds for each feature.
283284
/// </summary>
284285
/// <param name="featureIndex">The index of the feature (in the training vector) to get.</param>
285-
/// <returns>The bin upper bounds. May be null if this feature has no bins.</returns>
286-
public double[] GetFeatureBinUpperBounds(int featureIndex)
286+
/// <returns>The bin upper bounds. May be zero length if this feature has no bins.</returns>
287+
public double[] GetBinUpperBounds(int featureIndex)
287288
{
288289
Host.Check(0 <= featureIndex && featureIndex < NumShapeFunctions, "Index out of range.");
289-
double[] featureBins;
290-
if (_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
291-
{
292-
featureBins = new double[_binUpperBounds[j].Length];
293-
_binUpperBounds[j].CopyTo(featureBins, 0);
294-
}
295-
else
290+
if (!_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
291+
return new double[0];
292+
293+
var binUpperBounds = new double[_binUpperBounds[j].Length];
294+
_binUpperBounds[j].CopyTo(binUpperBounds, 0);
295+
return binUpperBounds;
296+
}
297+
298+
/// <summary>
299+
/// Get all the bin upper bounds.
300+
/// </summary>
301+
public double[][] GetBinUpperBounds()
302+
{
303+
double[][] binUpperBounds = new double[NumShapeFunctions][];
304+
for (int i = 0; i < NumShapeFunctions; i++)
296305
{
297-
featureBins = new double[0];
306+
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
307+
{
308+
binUpperBounds[i] = new double[_binUpperBounds[j].Length];
309+
_binUpperBounds[j].CopyTo(binUpperBounds[i], 0);
310+
}
311+
else
312+
{
313+
binUpperBounds[i] = new double[0];
314+
}
298315
}
299-
300-
return featureBins;
316+
return binUpperBounds;
301317
}
302318

303319
/// <summary>
304320
/// Get the binned weights for each feature.
305321
/// </summary>
306322
/// <param name="featureIndex">The index of the feature (in the training vector) to get.</param>
307-
/// <returns>The binned weights for each feature.</returns>
308-
public double[] GetFeatureWeights(int featureIndex)
323+
/// <returns>The binned effects for each feature. May be zero length if this feature has no bins.</returns>
324+
public double[] GetBinEffects(int featureIndex)
309325
{
310326
Host.Check(0 <= featureIndex && featureIndex < NumShapeFunctions, "Index out of range.");
311-
double[] featureWeights;
312-
if (_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
313-
{
314-
featureWeights = new double[_binEffects[j].Length];
315-
_binEffects[j].CopyTo(featureWeights, 0);
316-
}
317-
else
327+
if (!_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
328+
return new double[0];
329+
330+
var binEffects = new double[_binEffects[j].Length];
331+
_binEffects[j].CopyTo(binEffects, 0);
332+
return binEffects;
333+
}
334+
335+
/// <summary>
336+
/// Get all the binned effects.
337+
/// </summary>
338+
public double[][] GetBinEffects()
339+
{
340+
double[][] binEffects = new double[NumShapeFunctions][];
341+
for (int i = 0; i < NumShapeFunctions; i++)
318342
{
319-
featureWeights = new double[0];
343+
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
344+
{
345+
binEffects[i] = new double[_binEffects[j].Length];
346+
_binEffects[j].CopyTo(binEffects[i], 0);
347+
}
348+
else
349+
{
350+
binEffects[i] = new double[0];
351+
}
320352
}
321-
322-
return featureWeights;
353+
return binEffects;
323354
}
324355

325356
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)

src/Microsoft.ML.FastTree/QuantileStatistics.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public QuantileStatistics(float[] data, float[] weights = null, bool isSorted =
6060
if (!isSorted)
6161
Array.Sort(_data);
6262
else
63-
Contracts.Assert(Utils.IsSorted(_data));
63+
Contracts.Assert(Utils.IsMonotonicallyIncreasing(_data));
6464
}
6565

6666
/// <summary>

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL
512512
Host.CheckDecode(numStarts == _numClasses + 1);
513513
int[] starts = ctx.Reader.ReadIntArray(numStarts);
514514
Host.CheckDecode(starts[0] == 0);
515-
Host.CheckDecode(Utils.IsSorted(starts));
515+
Host.CheckDecode(Utils.IsMonotonicallyIncreasing(starts));
516516

517517
int numIndices = ctx.Reader.ReadInt32();
518518
Host.CheckDecode(numIndices == starts[starts.Length - 1]);

0 commit comments

Comments
 (0)