From f03e5c1d74489105fee6096cd9f51a3cdb195711 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 25 Sep 2018 22:40:26 -0700 Subject: [PATCH 1/7] Adding the Ranker TrainContext, the Ranker TrainerEstimatorReconcilier, and an Evaluate method + metrics class to the existing RankerEvaluator. --- docs/images/DCG.png | Bin 0 -> 1233 bytes docs/images/NDCG.png | Bin 0 -> 1057 bytes .../Evaluators/EvaluatorStaticExtensions.cs | 34 ++++++++ .../Evaluators/RankerEvaluator.cs | 82 +++++++++++++++++- .../StaticPipe/TrainerEstimatorReconciler.cs | 64 ++++++++++++++ .../Training/TrainContext.cs | 43 +++++++++ 6 files changed, 220 insertions(+), 3 deletions(-) create mode 100644 docs/images/DCG.png create mode 100644 docs/images/NDCG.png diff --git a/docs/images/DCG.png b/docs/images/DCG.png new file mode 100644 index 0000000000000000000000000000000000000000..5322bc0b9dff7b8127b03b81989af1342069f155 GIT binary patch literal 1233 zcmV;?1TOoDP)Px#jZjQfMF0Q*0000006G8wW&k=m06LHWX0|#206G9VIsj%mIsiI4IyyR#I%c*y zkdQi%+hza&W&k>70A^-7I%YbMW@ct)X0~RK+h(@^kU9X6Iy#VMI*_)IklX*ZW&pNk zI<{tJwvcAFwq~}r|F-|Ow*UXzkUHD8X50U^+yC3!|NsBCX8+re|NplC|JxP`Sycc4 z00DGTPE!Ct=GbNc000SaNLh0L01m_e01m_fl`9S#0000PbVXQnQ*UN;cVTj608MFQ za&L2QW^^D=W@c$)Wq<88aR2}Vt4TybRA@u(T8UEIFcg%bCbXr*&;v{xLyxF!eEk3C zb@#~*b}ZR)<9IYJnt>>>k@xM}cSu30|8L~cg({QtZK#%4pD#B`--ilM6?fTMJq1gj z%J1{|*N0uO{5@@}hc9=*!Ye_Yj>45+25jYRxbh3uu`j^>VFvQZInFjA*qs5NC8-pv zC=7gM5?PVtf#6eMDvJA6txUq;Pbo_9qA#^96nx4qK^+XfDz!^6gzyy%j1h|jgTd#L z-~d2GRs40i1gQ}t>j`}2$eH(MPg zR;5}oT-i+dp@fz6btb?vzC!<_raC%>p4ldP$#r@ARq(;inYF$Xrn=&nAh& zG{=6T;~jtlrKBf~bBo=OQl33Tq0E(aMW4i%H7hm~(ub>6ifdrwaZ}!Gnw_^umBSvY zJ~H(SLq1(76XkKp1Rh!gG_L{`@!{M8SBfjIgP&q`K%aa2G& zYkdHtS6j$L=HHGq`n8a(JF5(rTZ*1$j#>G43W73i;;a2-BCrPVY)AcxTs5Ae_hu7u z+Y*QPsPks@nQ7t6){SRbu4W%D@c|#c7(r#mLxsx@&u#wp8xk+EPc*<;=_N%$gc{B> z#I!M4@tr_`4)IY{#0SOGc?Mroih8NE@$U)nxien}pJ=hI$c6knkons93^D9EnqkG) z4*7{Pv)^8<+~tK2Yy2x{P&%Kx`Yt|GFJ{j#cJY~Rmx0S4)FzO}{OpQf*-;ST?DEU3 zW!JWM^r42Wz^ua|K50IBW8VzBAdKEH6}wA(c2=jZe)RpW%!dx+BxQ6eVA)37M+o>r z|9)nGKHtNbsUCZ|e{lx7$Dm$2h{ZAT^oj+lOkT}h0gkohpp?yCp^n9OaJUs7J~m8^ vf6Zn;sBv%%8XmsRf2}cPpJIm?l1u&oro~zJzy!|e00000NkvXXu0mjf!!1Du literal 0 HcmV?d00001 diff --git a/docs/images/NDCG.png b/docs/images/NDCG.png new file mode 100644 index 0000000000000000000000000000000000000000..f76d63ee451387d959a2c1c76afb91c97f78ba30 GIT binary patch literal 1057 zcmeAS@N?(olHy`uVBq!ia0vp^CxKX*gBeJ29=~@PNHG=%xjQkeJ16rJ$eADD6XFV_ z!N7`vA&tQbh$b+k0ns)f--^M?iXjb%tw3OcRa%-AP<+A!s|mN$psImF6VgBcNN!7; zFd+>nwGD{=Pq1Q`0M-K(0_(Uv;r9P+X$;%athND_PO#c`d&0K=x3~TOe|v(}?QLnd z!Sc8N-@g6-KTsV|6i5Lj%}?y@1-g&1B*-tA!Qt5rpv4TF1s;*b3=DinK$vl=HlH+5 zP_o1|q9iy!t)x7$D3!rCGr1_g7|2ubPfN>8POX1yTFAh_^wQJCF(ktMZA5nQEdzlq z&Y24`XE7}Gx6Vv=gd9_H#9CbG15J@Ld=5IZZ>~6 z*S6NFNyfoH($!cUE_nJL&YO6^X!4&=lXv9r(`ZmW;4-`7Wbe&+3_`chU3ko^z~o=u z>b~p2g{jh06r_&wDhNG&d`$X|kWH?)jdR7?RcV2y?;u-?49Zoh5NbRKlfkMpDU4Zk8KA${%GOdQ>9Lxj&Kn!p!Z@rQc@F zoH@^O2j?;IxpyWhFG-qdzSY#^vb^N_&mZ?sa#0GrF)e870q0~HrcB}Z}!PQ zOKZOgR5Ot$i|zi$<0{rdTi;v%Fv{_8_`gx5St0Y8hU)~$edR6XwR<)RU#Pb~si5^| z+m9PNimtdtDIc|*m9hTySN&DXleaYdTkP4k*JFX@xtxuq4<47^_#gdix=E*^st>5$~Q|;dbiE}+m&8(`8Qx;qm^VsxTq)KJ0 z;gS9CiqEDiU3~vem!a^ATX{NL@0IwY{h zZ<_hPXPFsa!HVwz8QgMr-tJX8_ckoqDDcA95b3>hC5{BZ`VzTaGa|D-)D8H z?(ferdK!0jP6{tx#N?OfpwqS7=YzM-R>PE)^HcJ2wtiIcEB`mdKI;Vst06?wj)c^nh literal 0 HcmV?d00001 diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs index 8fed434a94..03ff419981 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs @@ -167,5 +167,39 @@ public static RegressionEvaluator.Result Evaluate( args.LossFunction = new TrivialRegressionLossFactory(loss); return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName); } + + /// + /// Evaluates scored ranking data. + /// + /// The shape type for the input data. + /// The regression context. + /// The data to evaluate. + /// The index delegate for the label column. + /// The index delegate for the groupId column. + /// The index delegate for predicted score column. + /// The evaluation metrics. + public static RankerEvaluator.Result Evaluate( + this RegressionContext ctx, + DataView data, + Func> label, + Func> groupId, + Func> score) + { + Contracts.CheckValue(data, nameof(data)); + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckValue(label, nameof(label)); + env.CheckValue(groupId, nameof(groupId)); + env.CheckValue(score, nameof(score)); + + var indexer = StaticPipeUtils.GetIndexer(data); + string labelName = indexer.Get(label(indexer.Indices)); + string scoreName = indexer.Get(score(indexer.Indices)); + string groupIdName = indexer.Get(groupId(indexer.Indices)); + + var args = new RankerEvaluator.Arguments() { }; + + return new RankerEvaluator(env, args).Evaluate(data.AsDynamic, labelName, groupIdName, scoreName); + } } } diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index d071a906e3..a305018e15 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -41,17 +41,17 @@ public sealed class Arguments public bool OutputGroupSummary; } - public const string LoadName = "RankingEvaluator"; + internal const string LoadName = "RankingEvaluator"; public const string Ndcg = "NDCG"; public const string Dcg = "DCG"; public const string MaxDcg = "MaxDCG"; - /// + /// /// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group. /// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one /// group in the scored data. - /// + /// public const string GroupSummary = "GroupSummary"; private const string GroupId = "GroupId"; @@ -234,6 +234,40 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A }; } + /// + /// Evaluates scored regression data. + /// + /// The data to evaluate. + /// The name of the label column. + /// The name of the groupId column. + /// The name of the predicted score column. + /// The evaluation metrics for these outputs. + public Result Evaluate(IDataView data, string label, string groupId, string score) + { + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + var roles = new RoleMappedData(data, opt: false, + RoleMappedSchema.ColumnRole.Label.Bind(label), + RoleMappedSchema.ColumnRole.Group.Bind(groupId), + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score)); + + var resultDict = Evaluate(roles); + Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); + var overall = resultDict[MetricKinds.OverallMetrics]; + + Result result; + using (var cursor = overall.GetRowCursor(i => true)) + { + var moved = cursor.MoveNext(); + Host.Assert(moved); + result = new Result(Host, cursor); + moved = cursor.MoveNext(); + Host.Assert(!moved); + } + return result; + } + public sealed class Aggregator : AggregatorBase { public sealed class Counters @@ -509,6 +543,48 @@ public void GetSlotNames(ref VBuffer> slotNames) slotNames = new VBuffer>(UnweightedCounters.TruncationLevel, values); } } + + public sealed class Result + { + /// + /// Normalized Discounted Cumulative Gain + /// + /// + public double[] Ndcg { get; } + + /// + /// Discounted Cumulative gain + /// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1. + /// Note that unline the Wikipedia article, ML.Net uses the natural logarithm. + /// + /// + public double[] Dcg { get; } + + /// + /// MaxDcgs is the value of when the documents are ordered in the ideal order from most relevant to least relevant. + /// In case there are ties in scores, metrics are computed in a pessimistic fashion. In other words, if two or more results get the same score, + /// for the purpose of computing DCG and NDCG they are ordered from least relevant to most relevant. + /// + public double[] MaxDcg { get; } + + private static T Fetch(IExceptionContext ectx, IRow row, string name) + { + if (!row.Schema.TryGetColumnIndex(name, out int col)) + throw ectx.Except($"Could not find column '{name}'"); + T val = default; + row.GetGetter(col)(ref val); + return val; + } + + internal Result(IExceptionContext ectx, IRow overallResult) + { + double[] Fetch(string name) => Fetch(ectx, overallResult, name); + + Dcg = Fetch(RankerEvaluator.Dcg); + Ndcg = Fetch(RankerEvaluator.Ndcg); + MaxDcg = Fetch(RankerEvaluator.MaxDcg); + } + } } public sealed class RankerPerInstanceTransform : IDataTransform diff --git a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs index ff2b9d7c02..952ec7bf21 100644 --- a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs @@ -403,5 +403,69 @@ public ImplScore(MulticlassClassifier rec) : base(rec, rec.Inputs) { } } } + /// + /// A reconciler for regression capable of handling the most common cases for regression. + /// + public sealed class Ranker : TrainerEstimatorReconciler + { + /// + /// The delegate to create the regression trainer instance. + /// + /// The environment with which to create the estimator + /// The label column name + /// The features column name + /// The weights column name, or null if the reconciler was constructed with null weights + /// The groupID column name. + /// A estimator producing columns with the fixed name . + public delegate IEstimator EstimatorFactory(IHostEnvironment env, string label, string features, string weights, string groupId); + + private readonly EstimatorFactory _estFact; + + /// + /// The output score column for ranking. This will have this instance as its reconciler. + /// + public Scalar Score { get; } + + protected override IEnumerable Outputs => Enumerable.Repeat(Score, 1); + + private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score }; + + /// + /// Constructs a new general ranker reconciler. + /// + /// The delegate to create the training estimator. It is assumed that this estimator + /// will produce a single new scalar column named . + /// The input label column. + /// The input features column. + /// The input weights column, or null if there are no weights. + /// The input groupId column. + public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar weights, Scalar groupId) + : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), + Contracts.CheckRef(features, nameof(features)), + Contracts.CheckRef(groupId, nameof(groupId)), + weights), + _fixedOutputNames) + { + Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory)); + _estFact = estimatorFactory; + Contracts.Assert(Inputs.Length == 2 || Inputs.Length == 3); + Score = new Impl(this); + } + + private static PipelineColumn[] MakeInputs(Scalar label, Vector features, Scalar weights, Scalar groupId) + => weights == null ? new PipelineColumn[] { label, features, groupId } : new PipelineColumn[] { label, features, groupId, weights }; + + protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) + { + Contracts.AssertValue(env); + env.Assert(Utils.Size(inputNames) == Inputs.Length); + return _estFact(env, inputNames[0], inputNames[1], inputNames[2], inputNames.Length > 3 ? inputNames[3] : null); + } + + private sealed class Impl : Scalar + { + public Impl(Ranker rec) : base(rec, rec.Inputs) { } + } + } } } diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index b15bd1b317..3d06ef4b5f 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -234,4 +234,47 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string return eval.Evaluate(data, label, score); } } + + /// + /// The central context for regression trainers. + /// + public sealed class RankerContext : TrainContextBase + { + /// + /// For trainers for performing regression. + /// + public RankerTrainers Trainers { get; } + + public RankerContext(IHostEnvironment env) + : base(env, nameof(RankerContext)) + { + Trainers = new RankerTrainers(this); + } + + public sealed class RankerTrainers : ContextInstantiatorBase + { + internal RankerTrainers(RankerContext ctx) + : base(ctx) + { + } + } + + /// + /// Evaluates scored regression data. + /// + /// The scored data. + /// The name of the label column in . + /// The name of the groupId column in . + /// The name of the score column in . + /// The evaluation results for these calibrated outputs. + public RankerEvaluator.Result Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score) + { + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + + var eval = new RankerEvaluator(Host, new RankerEvaluator.Arguments() { }); + return eval.Evaluate(data, label, groupId, score); + } + } } From fdcfe0c334c9300c375bc95ac62ab2da8e3844b0 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 25 Sep 2018 23:07:49 -0700 Subject: [PATCH 2/7] Adding the FastTree ranking xtension and test. --- .../StaticPipe/TrainerEstimatorReconciler.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeStatic.cs | 42 +++++++++++++++++++ .../Training.cs | 39 +++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs index 952ec7bf21..010dc5916a 100644 --- a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs @@ -439,7 +439,7 @@ public sealed class Ranker : TrainerEstimatorReconciler /// The input features column. /// The input weights column, or null if there are no weights. /// The input groupId column. - public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar weights, Scalar groupId) + public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar groupId, Scalar weights) : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), Contracts.CheckRef(groupId, nameof(groupId)), diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index 242834091f..6933e7fd26 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -105,6 +105,48 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } + /// + /// FastTree . + /// + /// The . + /// The label column. + /// The features colum. + /// The name of the groupId column. + /// The weights column. + /// The number of leaves to use. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The Score output column indicating the predicted value. + public static Scalar FastTree(this RankerContext.RankerTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName,advancedSettings); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int numLeaves, int numTrees, diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index e3fc685c6d..7ee61bc59d 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -343,5 +343,44 @@ public void FastTreeRegression() Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); } + + [Fact] + public void FastTreeRanking() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new RankerContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), + separator: ';', hasHeader: true); + + FastTreeRankingPredictor pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, r.features, r.groupId.ToKey())) + .Append(r => (r.label, r.groupId, score: ctx.Trainers.FastTree(r.label, r.features, r.groupId, + onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + // 11 input features, so we ought to have 11 weights. + VBuffer weights = new VBuffer(); + pred.GetFeatureWeights(ref weights); + Assert.Equal(11, weights.Length); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); + // Run a sanity check against a few of the metrics. + //Assert.InRange(metrics.Dcg, 0, double.PositiveInfinity); + //Assert.InRange(metrics.Ndcg, 0, double.PositiveInfinity); + //Assert.InRange(metrics.MaxDcg, 0, double.PositiveInfinity); + } } } From 4857e80f0ac87f70877f34e500e853e5a06345aa Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 26 Sep 2018 22:26:42 -0700 Subject: [PATCH 3/7] The groupId column should be a key, rather than a float. fixing the comments. adding checks to the test. --- .../Evaluators/EvaluatorStaticExtensions.cs | 9 +++--- .../Evaluators/RankerEvaluator.cs | 8 +++--- .../StaticPipe/TrainerEstimatorReconciler.cs | 16 +++++------ .../Training/TrainContext.cs | 3 +- src/Microsoft.ML.FastTree/FastTreeStatic.cs | 6 ++-- .../Training.cs | 28 ++++++++++--------- 6 files changed, 37 insertions(+), 33 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs index 03ff419981..394e446a57 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs @@ -172,17 +172,18 @@ public static RegressionEvaluator.Result Evaluate( /// Evaluates scored ranking data. /// /// The shape type for the input data. - /// The regression context. + /// The type of data, before being converted to a key. + /// The ranking context. /// The data to evaluate. /// The index delegate for the label column. /// The index delegate for the groupId column. /// The index delegate for predicted score column. /// The evaluation metrics. - public static RankerEvaluator.Result Evaluate( - this RegressionContext ctx, + public static RankerEvaluator.Result Evaluate( + this RankerContext ctx, DataView data, Func> label, - Func> groupId, + Func> groupId, Func> score) { Contracts.CheckValue(data, nameof(data)); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index a305018e15..7ad0d7ed0e 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -578,11 +578,11 @@ private static T Fetch(IExceptionContext ectx, IRow row, string name) internal Result(IExceptionContext ectx, IRow overallResult) { - double[] Fetch(string name) => Fetch(ectx, overallResult, name); + VBuffer Fetch(string name) => Fetch>(ectx, overallResult, name); - Dcg = Fetch(RankerEvaluator.Dcg); - Ndcg = Fetch(RankerEvaluator.Ndcg); - MaxDcg = Fetch(RankerEvaluator.MaxDcg); + Dcg = Fetch(RankerEvaluator.Dcg).Values; + Ndcg = Fetch(RankerEvaluator.Ndcg).Values; + // MaxDcg = Fetch(RankerEvaluator.MaxDcg).Values; } } } diff --git a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs index 010dc5916a..6e71f2c610 100644 --- a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs @@ -404,18 +404,18 @@ public ImplScore(MulticlassClassifier rec) : base(rec, rec.Inputs) { } } /// - /// A reconciler for regression capable of handling the most common cases for regression. + /// A reconciler for ranking capable of handling the most common cases for ranking. /// - public sealed class Ranker : TrainerEstimatorReconciler + public sealed class Ranker : TrainerEstimatorReconciler { /// - /// The delegate to create the regression trainer instance. + /// The delegate to create the ranking trainer instance. /// /// The environment with which to create the estimator /// The label column name /// The features column name /// The weights column name, or null if the reconciler was constructed with null weights - /// The groupID column name. + /// The groupId column name. /// A estimator producing columns with the fixed name . public delegate IEstimator EstimatorFactory(IHostEnvironment env, string label, string features, string weights, string groupId); @@ -439,7 +439,7 @@ public sealed class Ranker : TrainerEstimatorReconciler /// The input features column. /// The input weights column, or null if there are no weights. /// The input groupId column. - public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector features, Scalar groupId, Scalar weights) + public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector features, Key groupId, Scalar weights) : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features)), Contracts.CheckRef(groupId, nameof(groupId)), @@ -448,11 +448,11 @@ public Ranker(EstimatorFactory estimatorFactory, Scalar label, Vector label, Vector features, Scalar weights, Scalar groupId) + private static PipelineColumn[] MakeInputs(Scalar label, Vector features, Key groupId, Scalar weights) => weights == null ? new PipelineColumn[] { label, features, groupId } : new PipelineColumn[] { label, features, groupId, weights }; protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) @@ -464,7 +464,7 @@ protected override IEstimator ReconcileCore(IHostEnvironment env, private sealed class Impl : Scalar { - public Impl(Ranker rec) : base(rec, rec.Inputs) { } + public Impl(Ranker rec) : base(rec, rec.Inputs) { } } } } diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index 3d06ef4b5f..860e8bc84b 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -260,7 +260,7 @@ internal RankerTrainers(RankerContext ctx) } /// - /// Evaluates scored regression data. + /// Evaluates scored ranking data. /// /// The scored data. /// The name of the label column in . @@ -272,6 +272,7 @@ public RankerEvaluator.Result Evaluate(IDataView data, string label, string grou Host.CheckValue(data, nameof(data)); Host.CheckNonEmpty(label, nameof(label)); Host.CheckNonEmpty(score, nameof(score)); + Host.CheckNonEmpty(groupId, nameof(groupId)); var eval = new RankerEvaluator(Host, new RankerEvaluator.Arguments() { }); return eval.Evaluate(data, label, groupId, score); diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index 6933e7fd26..7b38c5d3ac 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -124,8 +124,8 @@ public static (Scalar score, Scalar probability, Scalar pred /// the linear model that was trained. Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. - public static Scalar FastTree(this RankerContext.RankerTrainers ctx, - Scalar label, Vector features, Key groupId, Scalar weights = null, + public static Scalar FastTree(this RankerContext.RankerTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, @@ -135,7 +135,7 @@ public static Scalar FastTree(this RankerContext.RankerTrainers ctx, { CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); - var rec = new TrainerEstimatorReconciler.Ranker( + var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName,advancedSettings); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 7ee61bc59d..05e5beac5b 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -355,32 +355,34 @@ public void FastTreeRanking() var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), - separator: ';', hasHeader: true); + separator: '\t', hasHeader: true); FastTreeRankingPredictor pred = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, r.features, r.groupId.ToKey())) - .Append(r => (r.label, r.groupId, score: ctx.Trainers.FastTree(r.label, r.features, r.groupId, - onFit: (p) => { pred = p; }))); + .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) + .Append(r => (r.label, r.groupId, score: ctx.Trainers.FastTree(r.label, r.features, r.groupId, onFit: (p) => { pred = p; }))); var pipe = reader.Append(est); Assert.Null(pred); var model = pipe.Fit(dataSource); Assert.NotNull(pred); - // 11 input features, so we ought to have 11 weights. - VBuffer weights = new VBuffer(); - pred.GetFeatureWeights(ref weights); - Assert.Equal(11, weights.Length); var data = model.Read(dataSource); - var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); - // Run a sanity check against a few of the metrics. - //Assert.InRange(metrics.Dcg, 0, double.PositiveInfinity); - //Assert.InRange(metrics.Ndcg, 0, double.PositiveInfinity); - //Assert.InRange(metrics.MaxDcg, 0, double.PositiveInfinity); + var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score); + Assert.NotNull(metrics); + + Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3); + + Assert.InRange(metrics.Dcg[0], 1.4, 1.6); + Assert.InRange(metrics.Dcg[1], 1.4, 1.8); + Assert.InRange(metrics.Dcg[2], 1.4, 1.8); + + Assert.InRange(metrics.Ndcg[0], 36.5, 37); + Assert.InRange(metrics.Ndcg[1], 36.5, 37); + Assert.InRange(metrics.Ndcg[2], 36.5, 37); } } } From 52c46418811cc361fa82ef08d7c20d5a2fc6e319 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 27 Sep 2018 23:11:18 -0700 Subject: [PATCH 4/7] Grouping the xtensions in classes with more meaningful names, since the docs site displays the methods per class, not file. Addressing Tom's comments . --- .../Evaluators/RankerEvaluator.cs | 8 -- src/Microsoft.ML.FastTree/FastTreeStatic.cs | 72 +++++++++------- src/Microsoft.ML.LightGBM/LightGBMStatics.cs | 14 ++- .../Standard/SdcaStatic.cs | 85 ++++++++++--------- 4 files changed, 99 insertions(+), 80 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 7ad0d7ed0e..5f7e9c5fa6 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -560,13 +560,6 @@ public sealed class Result /// public double[] Dcg { get; } - /// - /// MaxDcgs is the value of when the documents are ordered in the ideal order from most relevant to least relevant. - /// In case there are ties in scores, metrics are computed in a pessimistic fashion. In other words, if two or more results get the same score, - /// for the purpose of computing DCG and NDCG they are ordered from least relevant to most relevant. - /// - public double[] MaxDcg { get; } - private static T Fetch(IExceptionContext ectx, IRow row, string name) { if (!row.Schema.TryGetColumnIndex(name, out int col)) @@ -582,7 +575,6 @@ internal Result(IExceptionContext ectx, IRow overallResult) Dcg = Fetch(RankerEvaluator.Dcg).Values; Ndcg = Fetch(RankerEvaluator.Ndcg).Values; - // MaxDcg = Fetch(RankerEvaluator.MaxDcg).Values; } } } diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index 7b38c5d3ac..c5c0f4c86f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -16,18 +16,19 @@ namespace Microsoft.ML.Trainers /// /// FastTree extension methods. /// - public static class FastTreeStatic + public static partial class RegressionTrainers { /// /// FastTree extension method. + /// Predicts a target using a decision tree regression model trained with the . /// /// The . /// The label column. /// The features colum. - /// The weights column. - /// The number of leaves to use. + /// The optional weights column. /// Total number of decision trees to create in the ensemble. - /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. /// Algorithm advanced settings. /// A delegate that is called every time the @@ -40,18 +41,18 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c Scalar label, Vector features, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, - int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate= Defaults.LearningRates, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); + FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, numLeaves, - numTrees, minDocumentsInLeafs, learningRate, advancedSettings); + numTrees, minDatapointsInLeafs, learningRate, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -59,17 +60,22 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c return rec.Score; } + } + + public static partial class BinaryClassificationTrainers + { /// /// FastTree extension method. + /// Predict a target using a decision tree binary classificaiton model trained with the . /// /// The . /// The label column. /// The features colum. - /// The weights column. - /// The number of leaves to use. + /// The optional weights column. /// Total number of decision trees to create in the ensemble. - /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. /// Algorithm advanced settings. /// A delegate that is called every time the @@ -83,39 +89,44 @@ public static (Scalar score, Scalar probability, Scalar pred Scalar label, Vector features, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, - int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null, Action> onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); + FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numLeaves, - numTrees, minDocumentsInLeafs, learningRate, advancedSettings); + numTrees, minDatapointsInLeafs, learningRate, advancedSettings); - if (onFit != null) - return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); - else - return trainer; + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; }, label, features, weights); return rec.Output; } + } + + public static partial class RankingTrainers + { /// /// FastTree . + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . /// /// The . /// The label column. /// The features colum. - /// The name of the groupId column. - /// The weights column. - /// The number of leaves to use. + /// The groupId column. + /// The optional weights column. /// Total number of decision trees to create in the ensemble. - /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. /// The learning rate. /// Algorithm advanced settings. /// A delegate that is called every time the @@ -128,17 +139,17 @@ public static Scalar FastTree(this RankerContext.RankerTrainers ctx Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, - int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); + FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Ranker( - (env, labelName, featuresName, groupIdName, weightsName) => + (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName,advancedSettings); + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -146,11 +157,14 @@ public static Scalar FastTree(this RankerContext.RankerTrainers ctx return rec.Score; } + } - private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + internal class FastTreeStaticsUtils + { + internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int numLeaves, int numTrees, - int minDocumentsInLeafs, + int minDatapointsInLeafs, double learningRate, Delegate advancedSettings, Delegate onFit) @@ -160,7 +174,7 @@ private static void CheckUserValues(PipelineColumn label, Vector features Contracts.CheckValueOrNull(weights); Contracts.CheckParam(numLeaves >= 2, nameof(numLeaves), "Must be at least 2."); Contracts.CheckParam(numTrees > 0, nameof(numTrees), "Must be positive"); - Contracts.CheckParam(minDocumentsInLeafs > 0, nameof(minDocumentsInLeafs), "Must be positive"); + Contracts.CheckParam(minDatapointsInLeafs > 0, nameof(minDatapointsInLeafs), "Must be positive"); Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive"); Contracts.CheckValueOrNull(advancedSettings); Contracts.CheckValueOrNull(onFit); diff --git a/src/Microsoft.ML.LightGBM/LightGBMStatics.cs b/src/Microsoft.ML.LightGBM/LightGBMStatics.cs index eb23d523b0..0ae4313fd5 100644 --- a/src/Microsoft.ML.LightGBM/LightGBMStatics.cs +++ b/src/Microsoft.ML.LightGBM/LightGBMStatics.cs @@ -16,7 +16,7 @@ namespace Microsoft.ML.Trainers /// /// LightGbm extension methods. /// - public static class LightGbmStatics + public static partial class RegressionTrainers { /// /// LightGbm extension method. @@ -45,7 +45,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -59,6 +59,9 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c return rec.Score; } + } + + public static partial class ClassificationTrainers { /// /// LightGbm extension method. @@ -88,7 +91,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -104,8 +107,11 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } + } + + internal class LightGbmStaticsUtils { - private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int? numLeaves, int? minDataPerLeaf, double? learningRate, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index 113092ed26..f49b17d0a4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -16,7 +16,7 @@ namespace Microsoft.ML.Trainers /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. /// - public static class SdcaStatic + public static partial class RegressionTrainers { /// /// Predict a target using a linear regression model trained with the SDCA trainer. @@ -73,6 +73,24 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, return rec.Score; } + private sealed class TrivialRegressionLossFactory : ISupportSdcaRegressionLossFactory + { + private readonly ISupportSdcaRegressionLoss _loss; + + public TrivialRegressionLossFactory(ISupportSdcaRegressionLoss loss) + { + _loss = loss; + } + + public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env) + { + return _loss; + } + } + } + public static partial class BinaryClassificationTrainers + { + /// /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss. /// @@ -91,12 +109,12 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, /// The set of output columns including in order the predicted binary classification score (which will range /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. public static (Scalar score, Scalar probability, Scalar predictedLabel) Sdca( - this BinaryClassificationContext.BinaryClassificationTrainers ctx, - Scalar label, Vector features, Scalar weights = null, - float? l2Const = null, - float? l1Threshold = null, - int? maxIterations = null, - Action onFit = null) + this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, Vector features, Scalar weights = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null, + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); @@ -205,6 +223,9 @@ public static (Scalar score, Scalar predictedLabel) Sdca( return rec.Output; } + } + + public static partial class MultiClassClassificationTrainers { /// /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. @@ -224,15 +245,15 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. public static (Vector score, Key predictedLabel) - Sdca(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, - Key label, - Vector features, - ISupportSdcaClassificationLoss loss = null, - Scalar weights = null, - float? l2Const = null, - float? l1Threshold = null, - int? maxIterations = null, - Action onFit = null) + Sdca(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + ISupportSdcaClassificationLoss loss = null, + Scalar weights = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null, + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); @@ -264,34 +285,20 @@ public static (Vector score, Key predictedLabel) return rec.Output; } - private sealed class TrivialRegressionLossFactory : ISupportSdcaRegressionLossFactory - { - private readonly ISupportSdcaRegressionLoss _loss; + } - public TrivialRegressionLossFactory(ISupportSdcaRegressionLoss loss) - { - _loss = loss; - } + internal sealed class TrivialClassificationLossFactory : ISupportSdcaClassificationLossFactory + { + private readonly ISupportSdcaClassificationLoss _loss; - public ISupportSdcaRegressionLoss CreateComponent(IHostEnvironment env) - { - return _loss; - } + public TrivialClassificationLossFactory(ISupportSdcaClassificationLoss loss) + { + _loss = loss; } - private sealed class TrivialClassificationLossFactory : ISupportSdcaClassificationLossFactory + public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) { - private readonly ISupportSdcaClassificationLoss _loss; - - public TrivialClassificationLossFactory(ISupportSdcaClassificationLoss loss) - { - _loss = loss; - } - - public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) - { - return _loss; - } + return _loss; } } } From fb8470e88de3e084dd8e193a3f2f6b60842927e5 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 28 Sep 2018 11:11:34 -0700 Subject: [PATCH 5/7] namespace change --- src/Microsoft.ML.LightGBM/LightGBMStatics.cs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.LightGBM/LightGBMStatics.cs b/src/Microsoft.ML.LightGBM/LightGBMStatics.cs index 0ae4313fd5..e564c75a12 100644 --- a/src/Microsoft.ML.LightGBM/LightGBMStatics.cs +++ b/src/Microsoft.ML.LightGBM/LightGBMStatics.cs @@ -6,15 +6,13 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.LightGBM; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// - /// LightGbm extension methods. + /// Regression trainer estimators. /// public static partial class RegressionTrainers { @@ -61,6 +59,9 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c } } + /// + /// Binary Classification trainer estimators. + /// public static partial class ClassificationTrainers { /// @@ -109,7 +110,7 @@ public static (Scalar score, Scalar probability, Scalar pred } } - internal class LightGbmStaticsUtils { + internal static class LightGbmStaticsUtils { internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int? numLeaves, From a6da6c13f79423edaef1ea8f5d3a76691fa12c81 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 28 Sep 2018 11:23:07 -0700 Subject: [PATCH 6/7] namespace rename --- src/Microsoft.ML.FastTree/FastTreeStatic.cs | 4 +--- src/Microsoft.ML.KMeansClustering/KMeansStatic.cs | 3 +-- .../FactorizationMachine/FactorizationMachineStatic.cs | 4 +--- src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs | 4 +--- test/Microsoft.ML.StaticPipelineTesting/Training.cs | 3 +-- 5 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/FastTreeStatic.cs index c5c0f4c86f..41d1c1d4a0 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/FastTreeStatic.cs @@ -6,12 +6,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// FastTree extension methods. diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs index 2a64ef8d31..b22af197ee 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs @@ -4,11 +4,10 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.KMeans; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// The trainer context extensions for the . diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index ade11d4fbe..4b6fc19462 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -7,14 +7,12 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; using System; using System.Collections.Generic; using System.Linq; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Extension methods and utilities for instantiating FFM trainer estimators inside statically typed pipelines. diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index f49b17d0a4..2496f72da9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -7,11 +7,9 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; -namespace Microsoft.ML.Trainers +namespace Microsoft.ML.StaticPipe { /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index c187bdce05..2185529402 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -12,8 +12,7 @@ using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; -using Microsoft.ML.Runtime.Training; -using Microsoft.ML.Trainers; +using Microsoft.ML.StaticPipe; using System; using System.Linq; using Xunit; From 13e8d0bb0154f23d4e82ddbf3166d5938d449537 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 28 Sep 2018 19:54:10 -0700 Subject: [PATCH 7/7] post merge conflict. --- .../Scenarios/Api/CookbookSamples/CookbookSamples.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 15ffd7b0c6..d7cef14224 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -11,7 +11,6 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.StaticPipe; using Microsoft.ML.TestFramework; -using Microsoft.ML.Trainers; using Microsoft.ML.Transforms.Text; using Microsoft.ML.Transforms; using System;