Skip to content

Commit 759dafb

Browse files
Ivanidzo4kaZruty0
authored andcommitted
API scenarios implemented with low-level functions (dotnet#653)
Added scenarios implementation for low-level API
1 parent 307b38f commit 759dafb

18 files changed

+1143
-6
lines changed

src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public interface IArgsComponent : IComponentFactory
2424
/// <summary>
2525
/// An interface for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>).
2626
/// </summary>
27-
public interface IComponentFactory<out TComponent>: IComponentFactory
27+
public interface IComponentFactory<out TComponent> : IComponentFactory
2828
{
2929
TComponent CreateComponent(IHostEnvironment env);
3030
}
@@ -57,6 +57,21 @@ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
5757
}
5858
}
5959

60+
public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
61+
{
62+
private Func<IHostEnvironment, TComponent> _factory;
63+
64+
public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
65+
{
66+
_factory = factory;
67+
}
68+
69+
public TComponent CreateComponent(IHostEnvironment env)
70+
{
71+
return _factory(env);
72+
}
73+
}
74+
6075
/// <summary>
6176
/// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>).
6277
/// </summary>

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,14 +228,14 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
228228
// REVIEW: Different from original FastRank arguments (shortname l vs. nl). Different default from TLC FR Wrapper (20 vs. 20).
229229
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)]
230230
[TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
231-
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale:true, stepSize:4)]
231+
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
232232
public int NumLeaves = 20;
233233

234234
// REVIEW: Arrays not supported in GUI
235235
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
236236
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)]
237237
[TGUI(Description = "Minimum number of training instances required to form a leaf", SuggestedSweeps = "1,10,50")]
238-
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] {1, 10, 50})]
238+
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })]
239239
public int MinDocumentsInLeafs = 10;
240240

241241
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
@@ -364,17 +364,17 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
364364

365365
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
366366
[TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
367-
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale:true)]
367+
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)]
368368
public Double LearningRates = 0.2;
369369

370370
[Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")]
371371
[TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")]
372-
[TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale:true)]
372+
[TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale: true)]
373373
public Double Shrinkage = 1;
374374

375375
[Argument(ArgumentType.AtMostOnce, HelpText = "Dropout rate for tree regularization", ShortName = "tdrop")]
376376
[TGUI(SuggestedSweeps = "0,0.000000001,0.05,0.1,0.2")]
377-
[TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f})]
377+
[TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f })]
378378
public Double DropoutRate = 0;
379379

380380
[Argument(ArgumentType.AtMostOnce, HelpText = "Sample each query 1 in k times in the GetDerivatives function", ShortName = "sr")]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.TestFramework;
7+
using Xunit.Abstractions;
8+
9+
namespace Microsoft.ML.Tests.Scenarios.Api
10+
{
11+
/// <summary>
12+
/// Common utility functions for API scenarios tests.
13+
/// </summary>
14+
public partial class ApiScenariosTests : BaseTestClass
15+
{
16+
public ApiScenariosTests(ITestOutputHelper output) : base(output)
17+
{
18+
}
19+
20+
public const string IrisDataPath = "iris.data";
21+
public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv";
22+
public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv";
23+
24+
public class IrisData : IrisDataNoLabel
25+
{
26+
public string Label;
27+
}
28+
29+
public class IrisDataNoLabel
30+
{
31+
public float SepalLength;
32+
public float SepalWidth;
33+
public float PetalLength;
34+
public float PetalWidth;
35+
}
36+
37+
public class IrisPrediction
38+
{
39+
public string PredictedLabel;
40+
public float[] Score;
41+
}
42+
43+
public class SentimentData
44+
{
45+
[ColumnName("Label")]
46+
public bool Sentiment;
47+
public string SentimentText;
48+
}
49+
50+
public class SentimentPrediction
51+
{
52+
[ColumnName("PredictedLabel")]
53+
public bool Sentiment;
54+
55+
public float Score;
56+
}
57+
}
58+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Learners;
7+
using Xunit;
8+
9+
namespace Microsoft.ML.Tests.Scenarios.Api
10+
{
11+
public partial class ApiScenariosTests
12+
{
13+
/// <summary>
14+
/// Auto-normalization and caching: It should be relatively easy for normalization
15+
/// and caching to be introduced for training, if the trainer supports or would benefit
16+
/// from that.
17+
/// </summary>
18+
[Fact]
19+
public void AutoNormalizationAndCaching()
20+
{
21+
var dataPath = GetDataPath(SentimentDataPath);
22+
var testDataPath = GetDataPath(SentimentTestPath);
23+
24+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
25+
{
26+
// Pipeline.
27+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
28+
29+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
30+
31+
// Train.
32+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
33+
{
34+
NumThreads = 1,
35+
ConvergenceTolerance = 1f
36+
});
37+
38+
// Auto-caching.
39+
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, trans, prefetch: null) : trans;
40+
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
41+
42+
// Auto-normalization.
43+
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
44+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
45+
}
46+
47+
}
48+
}
49+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Models;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Learners;
8+
using System;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
using Xunit;
12+
13+
namespace Microsoft.ML.Tests.Scenarios.Api
14+
{
15+
public partial class ApiScenariosTests
16+
{
17+
/// <summary>
18+
/// Cross-validation: Have a mechanism to do cross validation, that is, you come up with
19+
/// a data source (optionally with stratification column), come up with an instantiable transform
20+
/// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate
21+
/// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of
22+
/// evaluations and optionally trained pipes. (People always want metrics out of xfold,
23+
/// they sometimes want the actual models too.)
24+
/// </summary>
25+
[Fact]
26+
void CrossValidation()
27+
{
28+
var dataPath = GetDataPath(SentimentDataPath);
29+
var testDataPath = GetDataPath(SentimentTestPath);
30+
31+
int numFolds = 5;
32+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
33+
{
34+
// Pipeline.
35+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
36+
37+
var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
38+
IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
39+
// Train.
40+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
41+
{
42+
NumThreads = 1,
43+
ConvergenceTolerance = 1f
44+
});
45+
46+
47+
var metrics = new List<BinaryClassificationMetrics>();
48+
for (int fold = 0; fold < numFolds; fold++)
49+
{
50+
IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
51+
{
52+
Column = "StratificationColumn",
53+
Min = (Double)fold / numFolds,
54+
Max = (Double)(fold + 1) / numFolds,
55+
Complement = true
56+
}, trans);
57+
trainPipe = new OpaqueDataView(trainPipe);
58+
var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
59+
// Auto-normalization.
60+
NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
61+
var preCachedData = trainData;
62+
// Auto-caching.
63+
if (trainer.Info.WantCaching)
64+
{
65+
var prefetch = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
66+
var cacheView = new CacheDataView(env, trainData.Data, prefetch);
67+
// Because the prefetching worked, we know that these are valid columns.
68+
trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
69+
}
70+
71+
var predictor = trainer.Train(new Runtime.TrainContext(trainData));
72+
IDataView testPipe = new RangeFilter(env, new RangeFilter.Arguments()
73+
{
74+
Column = "StratificationColumn",
75+
Min = (Double)fold / numFolds,
76+
Max = (Double)(fold + 1) / numFolds,
77+
Complement = false
78+
}, trans);
79+
testPipe = new OpaqueDataView(testPipe);
80+
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);
81+
82+
var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());
83+
84+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);
85+
86+
BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
87+
var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
88+
var dict = eval.Evaluate(dataEval);
89+
var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
90+
metrics.Add(foldMetrics.Single());
91+
}
92+
}
93+
}
94+
}
95+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Learners;
8+
using System.Linq;
9+
using Xunit;
10+
11+
namespace Microsoft.ML.Tests.Scenarios.Api
12+
{
13+
14+
public partial class ApiScenariosTests
15+
{
16+
/// <summary>
17+
/// Decomposable train and predict: Train on Iris multiclass problem, which will require
18+
/// a transform on labels. Be able to reconstitute the pipeline for a prediction only task,
19+
/// which will essentially "drop" the transform over labels, while retaining the property
20+
/// that the predicted label for this has a key-type, the probability outputs for the classes
21+
/// have the class labels as slot names, etc. This should be do-able without ugly compromises like,
22+
/// say, injecting a dummy label.
23+
/// </summary>
24+
[Fact]
25+
void DecomposableTrainAndPredict()
26+
{
27+
var dataPath = GetDataPath(IrisDataPath);
28+
using (var env = new TlcEnvironment())
29+
{
30+
var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
31+
var term = new TermTransform(env, loader, "Label");
32+
var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
33+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
34+
35+
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
36+
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
37+
38+
// Auto-normalization.
39+
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
40+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
41+
42+
var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
43+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
44+
45+
// Cut out term transform from pipeline.
46+
var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
47+
var keyToValue = new KeyToValueTransform(env, newScorer, "PredictedLabel");
48+
var model = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue);
49+
50+
var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
51+
var testData = testLoader.AsEnumerable<IrisDataNoLabel>(env, false);
52+
foreach (var input in testData.Take(20))
53+
{
54+
var prediction = model.Predict(input);
55+
Assert.True(prediction.PredictedLabel == "Iris-setosa");
56+
}
57+
}
58+
}
59+
}
60+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Learners;
8+
using Xunit;
9+
using Microsoft.ML.Models;
10+
11+
namespace Microsoft.ML.Tests.Scenarios.Api
12+
{
13+
public partial class ApiScenariosTests
14+
{
15+
/// <summary>
16+
/// Evaluation: Similar to the simple train scenario, except instead of having some
17+
/// predictive structure, be able to score another "test" data file, run the result
18+
/// through an evaluator and get metrics like AUC, accuracy, PR curves, and whatnot.
19+
/// Getting metrics out of this shoudl be as straightforward and unannoying as possible.
20+
/// </summary>
21+
[Fact]
22+
public void Evaluation()
23+
{
24+
var dataPath = GetDataPath(SentimentDataPath);
25+
var testDataPath = GetDataPath(SentimentTestPath);
26+
27+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
28+
{
29+
// Pipeline
30+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
31+
32+
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
33+
34+
// Train
35+
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
36+
{
37+
NumThreads = 1
38+
});
39+
40+
var cached = new CacheDataView(env, trans, prefetch: null);
41+
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
42+
var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
43+
var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
44+
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
45+
46+
// Create prediction engine and test predictions.
47+
var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);
48+
49+
// Take a couple examples out of the test data and run predictions on top.
50+
var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
51+
var testData = testLoader.AsEnumerable<SentimentData>(env, false);
52+
53+
var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);
54+
55+
var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
56+
var metricsDict = evaluator.Evaluate(dataEval);
57+
58+
var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
59+
}
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)