|
1 | 1 | using Microsoft.ML.Core.Data;
|
| 2 | +using Microsoft.ML.Models; |
2 | 3 | using Microsoft.ML.Runtime;
|
3 | 4 | using Microsoft.ML.Runtime.Api;
|
4 | 5 | using Microsoft.ML.Runtime.Data;
|
|
9 | 10 | using System;
|
10 | 11 | using System.Collections.Generic;
|
11 | 12 | using System.IO;
|
| 13 | +using System.Linq; |
| 14 | + |
12 | 15 | [assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
|
13 | 16 | "Transform wrapper", TransformWrapper.LoaderSignature)]
|
14 | 17 | [assembly: LoadableClass(typeof(LoaderWrapper), null, typeof(SignatureLoadModel),
|
@@ -163,16 +166,16 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
|
163 | 166 | public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
|
164 | 167 | }
|
165 | 168 |
|
166 |
| - public interface IPredictorTransformer<out TModel>: ITransformer |
| 169 | + public interface IPredictorTransformer<out TModel> : ITransformer |
167 | 170 | {
|
168 | 171 | TModel TrainedModel { get; }
|
169 | 172 | }
|
170 | 173 |
|
171 |
| - public class ScorerWrapper<TModel>: TransformWrapper, IPredictorTransformer<TModel> |
172 |
| - where TModel: IPredictor |
| 174 | + public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel> |
| 175 | + where TModel : IPredictor |
173 | 176 | {
|
174 | 177 | public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
|
175 |
| - :base(env, scorer) |
| 178 | + : base(env, scorer) |
176 | 179 | {
|
177 | 180 | Model = trainedModel;
|
178 | 181 | }
|
@@ -206,7 +209,7 @@ public SchemaShape GetOutputSchema()
|
206 | 209 | }
|
207 | 210 |
|
208 | 211 | public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
|
209 |
| - where TModel: IPredictor |
| 212 | + where TModel : IPredictor |
210 | 213 | {
|
211 | 214 | protected readonly IHostEnvironment _env;
|
212 | 215 | private readonly string _featureCol;
|
@@ -298,12 +301,12 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
|
298 | 301 | public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
|
299 | 302 | }
|
300 | 303 |
|
301 |
| - public sealed class MyAveragedPerceptron: TrainerBase<IPredictor> |
| 304 | + public sealed class MyAveragedPerceptron : TrainerBase<IPredictor> |
302 | 305 | {
|
303 | 306 | private readonly AveragedPerceptronTrainer _trainer;
|
304 | 307 |
|
305 | 308 | public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
|
306 |
| - :base(env, false, featureCol, labelCol) |
| 309 | + : base(env, false, featureCol, labelCol) |
307 | 310 | {
|
308 | 311 | _trainer = new AveragedPerceptronTrainer(env, args);
|
309 | 312 | }
|
@@ -334,5 +337,31 @@ public TDst Predict(TSrc example)
|
334 | 337 | }
|
335 | 338 | }
|
336 | 339 |
|
| 340 | + public sealed class MyBinaryClassifierEvaluator |
| 341 | + { |
| 342 | + private readonly IHostEnvironment _env; |
| 343 | + private readonly BinaryClassifierEvaluator _evaluator; |
| 344 | + |
| 345 | + public MyBinaryClassifierEvaluator(IHostEnvironment env, BinaryClassifierEvaluator.Arguments args) |
| 346 | + { |
| 347 | + _env = env; |
| 348 | + _evaluator = new BinaryClassifierEvaluator(env, args); |
| 349 | + } |
| 350 | + |
| 351 | + public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn, string probabilityColumn) |
| 352 | + { |
| 353 | + var ci = EvaluateUtils.GetScoreColumnInfo(_env, data.Schema, null, "Score", MetadataUtils.Const.ScoreColumnKind.BinaryClassification); |
| 354 | + var map = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[] |
| 355 | + { |
| 356 | + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probabilityColumn), |
| 357 | + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, ci.Name) |
| 358 | + }; |
| 359 | + var rmd = new RoleMappedData(data, labelColumn, "Features", opt: true, custom: map); |
| 360 | + |
| 361 | + var metricsDict = _evaluator.Evaluate(rmd); |
| 362 | + return BinaryClassificationMetrics.FromMetrics(_env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"]).Single(); |
| 363 | + } |
| 364 | + } |
| 365 | + |
337 | 366 |
|
338 | 367 | }
|
0 commit comments