Skip to content

Commit c01c46f

Browse files
authored
TensorFlow static extensions, SDCA multiclass static extensions (#882)
TF static extensions; SDCA static extensions; multiclass evaluator; prediction function.
1 parent 5666dd1 commit c01c46f

21 files changed

+580
-112
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Api;
7+
8+
namespace Microsoft.ML.Runtime.Data
9+
{
10+
/// <summary>
/// Scores individual examples: each <typeparamref name="TSrc"/> instance handed to
/// <see cref="Predict"/> is run through the wrapped transformer pipeline and comes back
/// as a <typeparamref name="TDst"/> instance.
/// </summary>
/// <typeparam name="TSrc">The input example type.</typeparam>
/// <typeparam name="TDst">The output prediction type.</typeparam>
public sealed class PredictionFunction<TSrc, TDst>
    where TSrc : class
    where TDst : class, new()
{
    private readonly PredictionEngine<TSrc, TDst> _engine;

    /// <summary>
    /// Creates the prediction function from a host environment and a transformer.
    /// </summary>
    /// <param name="env">The host environment; checked for <c>null</c>.</param>
    /// <param name="transformer">The transformer to wrap; checked for <c>null</c>.</param>
    public PredictionFunction(IHostEnvironment env, ITransformer transformer)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(transformer, nameof(transformer));

        // The transformer is applied to a zero-row view of TSrc, and the prediction
        // engine is created over the transformed view.
        IDataView emptyInput = env.CreateDataView(new TSrc[0]);
        IDataView transformed = transformer.Transform(emptyInput);
        _engine = env.CreatePredictionEngine<TSrc, TDst>(transformed);
    }

    /// <summary>
    /// Produces the prediction for a single example by delegating to the underlying prediction engine.
    /// </summary>
    /// <param name="example">The example to score.</param>
    /// <returns>The prediction returned by the engine.</returns>
    public TDst Predict(TSrc example)
    {
        return _engine.Predict(example);
    }
}
31+
32+
/// <summary>
/// Extension methods for building a <see cref="PredictionFunction{TSrc, TDst}"/> from a transformer.
/// </summary>
public static class PredictionFunctionExtensions
{
    /// <summary>
    /// Builds a 'prediction function' (also called a 'prediction machine') from the model
    /// represented by <paramref name="transformer"/>. The result accepts
    /// <typeparamref name="TSrc"/> instances as input and produces
    /// <typeparamref name="TDst"/> instances as output.
    /// </summary>
    /// <param name="transformer">The model to wrap.</param>
    /// <param name="env">The host environment passed to the prediction function's constructor.</param>
    /// <returns>A new <see cref="PredictionFunction{TSrc, TDst}"/> over <paramref name="transformer"/>.</returns>
    public static PredictionFunction<TSrc, TDst> MakePredictionFunction<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env)
        where TSrc : class
        where TDst : class, new()
    {
        // Note the argument order: the constructor takes the environment first.
        return new PredictionFunction<TSrc, TDst>(env, transformer);
    }
}
45+
}

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

Lines changed: 223 additions & 60 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ private sealed class TrivialLossFactory : ISupportRegressionLossFactory
233233
}
234234

235235
/// <summary>
236-
/// Evaluates scored binary classification data.
236+
/// Evaluates scored regression data.
237237
/// </summary>
238238
/// <typeparam name="T">The shape type for the input data.</typeparam>
239239
/// <param name="data">The data to evaluate.</param>

src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public sealed override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
138138
public sealed class Regression : TrainerEstimatorReconciler
139139
{
140140
/// <summary>
141-
/// The delegate to create the <see cref="Regression"/> instance.
141+
/// The delegate to create the regression trainer instance.
142142
/// </summary>
143143
/// <param name="env">The environment with which to create the estimator</param>
144144
/// <param name="label">The label column name</param>
@@ -198,7 +198,7 @@ public Impl(Regression rec) : base(rec, rec._inputs) { }
198198
public sealed class BinaryClassifier : TrainerEstimatorReconciler
199199
{
200200
/// <summary>
201-
/// The delegate to create the <see cref="BinaryClassifier"/> instance.
201+
/// The delegate to create the binary classifier trainer instance.
202202
/// </summary>
203203
/// <param name="env">The environment with which to create the estimator.</param>
204204
/// <param name="label">The label column name.</param>
@@ -259,12 +259,12 @@ public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { }
259259

260260
/// <summary>
261261
/// A reconciler capable of handling the most common cases for binary classification that does not
262-
/// necessarily have with calibrated outputs.
262+
/// necessarily have calibrated outputs.
263263
/// </summary>
264264
public sealed class BinaryClassifierNoCalibration : TrainerEstimatorReconciler
265265
{
266266
/// <summary>
267-
/// The delegate to create the <see cref="BinaryClassifier"/> instance.
267+
/// The delegate to create the binary classifier trainer instance.
268268
/// </summary>
269269
/// <param name="env">The environment with which to create the estimator</param>
270270
/// <param name="label">The label column name.</param>
@@ -336,5 +336,72 @@ private sealed class ImplBool : Scalar<bool>
336336
public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { }
337337
}
338338
}
339+
340+
/// <summary>
341+
/// A reconciler for regression capable of handling the most common cases for regression.
342+
/// </summary>
343+
public sealed class MulticlassClassifier<TVal> : TrainerEstimatorReconciler
344+
{
345+
/// <summary>
346+
/// The delegate to create the multiclass classifier trainer instance.
347+
/// </summary>
348+
/// <param name="env">The environment with which to create the estimator</param>
349+
/// <param name="label">The label column name</param>
350+
/// <param name="features">The features column name</param>
351+
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
352+
/// <returns>A estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/> and <see cref="DefaultColumnNames.PredictedLabel"/>.</returns>
353+
public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string weights);
354+
355+
private readonly EstimatorFactory _estFact;
356+
357+
/// <summary>
358+
/// The general output for multiclass classifiers.
359+
/// </summary>
360+
public (Vector<float> score, Key<uint, TVal> predictedLabel) Output { get; }
361+
362+
protected override IEnumerable<PipelineColumn> Outputs => new PipelineColumn[] { Output.score, Output.predictedLabel };
363+
364+
private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel };
365+
366+
/// <summary>
367+
/// Constructs a new general multiclass classifier reconciler.
368+
/// </summary>
369+
/// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
370+
/// will produce a vector <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/> and a scalar
371+
/// key column named <see cref="DefaultColumnNames.PredictedLabel"/>.</param>
372+
/// <param name="label">The input label column.</param>
373+
/// <param name="features">The input features column.</param>
374+
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
375+
public MulticlassClassifier(EstimatorFactory estimatorFactory, Key<uint, TVal> label, Vector<float> features, Scalar<float> weights)
376+
: base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), weights),
377+
_fixedOutputNames)
378+
{
379+
Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
380+
_estFact = estimatorFactory;
381+
Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3);
382+
Output = (new ImplScore(this), new ImplLabel(this));
383+
}
384+
385+
private static PipelineColumn[] MakeInputs(Key<uint, TVal> label, Vector<float> features, Scalar<float> weights)
386+
=> weights == null ? new PipelineColumn[] { label, features } : new PipelineColumn[] { label, features, weights };
387+
388+
protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
389+
{
390+
Contracts.AssertValue(env);
391+
env.Assert(Utils.Size(inputNames) == _inputs.Length);
392+
return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null);
393+
}
394+
395+
private sealed class ImplLabel : Key<uint, TVal>
396+
{
397+
public ImplLabel(MulticlassClassifier<TVal> rec) : base(rec, rec._inputs) { }
398+
}
399+
400+
private sealed class ImplScore : Vector<float>
401+
{
402+
public ImplScore(MulticlassClassifier<TVal> rec) : base(rec, rec._inputs) { }
403+
}
404+
}
405+
339406
}
340407
}

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
4040
/// </summary>
4141
public readonly SchemaShape.Column WeightColumn;
4242

43-
/// <summary>
44-
/// The columns that will be created by the fitted transformer.
45-
/// </summary>
46-
protected abstract SchemaShape.Column[] OutputColumns { get; }
47-
4843
protected readonly IHost Host;
4944

5045
/// <summary>
@@ -76,12 +71,17 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
7671
CheckInputSchema(inputSchema);
7772

7873
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
79-
foreach (var col in OutputColumns)
74+
foreach (var col in GetOutputColumnsCore(inputSchema))
8075
outColumns[col.Name] = col;
8176

8277
return new SchemaShape(outColumns.Values);
8378
}
8479

80+
/// <summary>
81+
/// The columns that will be created by the fitted transformer.
82+
/// </summary>
83+
protected abstract SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema);
84+
8585
public TModel Train(TrainContext context)
8686
{
8787
Host.CheckValue(context, nameof(context));

src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,9 @@ internal override void Check(IHostEnvironment env)
13901390

13911391
protected override bool ShuffleData => _args.Shuffle;
13921392

1393-
protected override SchemaShape.Column[] OutputColumns { get; }
1393+
private readonly SchemaShape.Column[] _outputColumns;
1394+
1395+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
13941396

13951397
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
13961398

@@ -1408,15 +1410,15 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args,
14081410

14091411
if (Info.NeedCalibration)
14101412
{
1411-
OutputColumns = new[]
1413+
_outputColumns = new[]
14121414
{
14131415
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
14141416
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
14151417
};
14161418
}
14171419
else
14181420
{
1419-
OutputColumns = new[]
1421+
_outputColumns = new[]
14201422
{
14211423
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
14221424
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Learners
3030
// - Feature normalization. By default, rescaling between min and max values for every feature
3131
// - Prediction calibration to produce probabilities. Off by default, if on, uses exponential (aka Platt) calibration.
3232
/// <include file='doc.xml' path='doc/members/member[@name="AP"]/*' />
33-
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor> , LinearBinaryPredictor>
33+
public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPredictionTransformer<LinearBinaryPredictor>, LinearBinaryPredictor>
3434
{
3535
public const string LoadNameValue = "AveragedPerceptron";
3636
internal const string UserNameValue = "Averaged Perceptron";
@@ -57,7 +57,7 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
5757
_args = args;
5858
LossFunction = _args.LossFunction.CreateComponent(env);
5959

60-
OutputColumns = new[]
60+
_outputColumns = new[]
6161
{
6262
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
6363
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
@@ -69,7 +69,9 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
6969

7070
protected override bool NeedCalibration => true;
7171

72-
protected override SchemaShape.Column[] OutputColumns { get; }
72+
private readonly SchemaShape.Column[] _outputColumns;
73+
74+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
7375

7476
protected override void CheckLabel(RoleMappedData data)
7577
{

src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ public sealed class Arguments : OnlineLinearArguments
4848
{
4949
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
5050
[TGUI(SuggestedSweeps = "0.00001-0.1;log;inc:10")]
51-
[TlcModule.SweepableFloatParamAttribute("Lambda", 0.00001f, 0.1f, 10, isLogScale:true)]
51+
[TlcModule.SweepableFloatParamAttribute("Lambda", 0.00001f, 0.1f, 10, isLogScale: true)]
5252
public Float Lambda = (Float)0.001;
5353

5454
[Argument(ArgumentType.AtMostOnce, HelpText = "Batch size", ShortName = "batch", SortOrder = 190)]
5555
[TGUI(Label = "Batch Size")]
5656
public int BatchSize = 1;
5757

5858
[Argument(ArgumentType.AtMostOnce, HelpText = "Perform projection to unit-ball? Typically used with batch size > 1.", ShortName = "project", SortOrder = 50)]
59-
[TlcModule.SweepableDiscreteParam("PerformProjection", null, isBool:true)]
59+
[TlcModule.SweepableDiscreteParam("PerformProjection", null, isBool: true)]
6060
public bool PerformProjection = false;
6161

6262
[Argument(ArgumentType.AtMostOnce, HelpText = "No bias")]
@@ -93,7 +93,7 @@ public LinearSvm(IHostEnvironment env, Arguments args)
9393

9494
Args = args;
9595

96-
OutputColumns = new[]
96+
_outputColumns = new[]
9797
{
9898
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
9999
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
@@ -103,7 +103,9 @@ public LinearSvm(IHostEnvironment env, Arguments args)
103103

104104
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
105105

106-
protected override SchemaShape.Column[] OutputColumns { get; }
106+
private readonly SchemaShape.Column[] _outputColumns;
107+
108+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
107109

108110
protected override void CheckLabel(RoleMappedData data)
109111
{

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,16 @@ public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
5656
{
5757
LossFunction = args.LossFunction.CreateComponent(env);
5858

59-
OutputColumns = new[]
59+
_outputColumns = new[]
6060
{
6161
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false)
6262
};
6363
}
6464

6565
public override PredictionKind PredictionKind => PredictionKind.Regression;
6666

67-
protected override SchemaShape.Column[] OutputColumns { get; }
67+
private readonly SchemaShape.Column[] _outputColumns;
68+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
6869

6970
protected override void CheckLabel(RoleMappedData data)
7071
{

src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ public sealed class Arguments : ArgumentsBase
4646
private readonly Arguments _args;
4747

4848
public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
49-
protected override SchemaShape.Column[] OutputColumns { get; }
5049

5150
public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
5251
string featureColumn, string labelColumn, string weightColumn = null)
@@ -55,10 +54,18 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
5554
_loss = args.LossFunction.CreateComponent(env);
5655
Loss = _loss;
5756
_args = args;
58-
OutputColumns = new[]
57+
}
58+
59+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
60+
{
61+
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
62+
Contracts.Assert(success);
63+
64+
var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues));
65+
return new[]
5966
{
6067
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false),
61-
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)
68+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, labelCol.ItemType, labelCol.IsKey, metadata)
6269
};
6370
}
6471

src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,17 @@ public Arguments()
5151
private readonly Arguments _args;
5252

5353
public override PredictionKind PredictionKind => PredictionKind.Regression;
54-
protected override SchemaShape.Column[] OutputColumns { get; }
54+
55+
private readonly SchemaShape.Column[] _outputColumns;
56+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns;
5557

5658
public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null)
5759
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), args, MakeFeatureColumn(featureColumn), MakeLabelColumn(labelColumn), MakeWeightColumn(weightColumn))
5860
{
5961
_loss = args.LossFunction.CreateComponent(env);
6062
Loss = _loss;
6163
_args = args;
62-
OutputColumns = new[]
64+
_outputColumns = new[]
6365
{
6466
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
6567
};

0 commit comments

Comments
 (0)