Skip to content

Commit 35e93ef

Browse files
committed
added catalog extensions and moved tensorflow arguments
1 parent 8a3df28 commit 35e93ef

File tree

7 files changed

+149
-150
lines changed

7 files changed

+149
-150
lines changed

src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ private void MapDist(in VBuffer<float> src, ref float score, ref float prob)
234234
}
235235

236236
/// <summary>
237-
/// Learns the prior distribution for 0/1 class labels and just outputs that.
237+
/// Learns the prior distribution for 0/1 class labels and outputs that.
238238
/// </summary>
239239
public sealed class PriorTrainer : TrainerBase<PriorModelParameters>,
240240
ITrainerEstimator<BinaryPredictionTransformer<PriorModelParameters>, PriorModelParameters>
@@ -263,8 +263,8 @@ internal PriorTrainer(IHostEnvironment env, Options options)
263263
/// <summary>
264264
/// Initializes PriorTrainer object.
265265
/// </summary>
266-
internal PriorTrainer(IHost host, String labelColumn, String weightColunn = null)
267-
: base(host, LoadNameValue)
266+
internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightColunn = null)
267+
: base(env, LoadNameValue)
268268
{
269269
Contracts.CheckValue(labelColumn, nameof(labelColumn));
270270
Contracts.CheckValueOrNull(weightColunn);

src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs

+16-17
Original file line numberDiff line numberDiff line change
@@ -558,28 +558,27 @@ public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassifica
558558
}
559559

560560
/// <summary>
561-
/// Predict a target using a linear binary classification model trained with the <see cref="LinearSvmTrainer"/> trainer.
561+
/// Predict a target using a random binary classification model <see cref="RandomTrainer"/>.
562562
/// </summary>
563-
/// <remarks>
564-
/// <para>
565-
/// The idea behind support vector machines, is to map instances into a high dimensional space
566-
/// in which the two classes are linearly separable, i.e., there exists a hyperplane such that all the positive examples are on one side of it,
567-
/// and all the negative examples are on the other.
568-
/// </para>
569-
/// <para>
570-
/// After this mapping, quadratic programming is used to find the separating hyperplane that maximizes the
571-
/// margin, i.e., the minimal distance between it and the instances.
572-
/// </para>
573-
/// </remarks>
574563
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
575-
/// <param name="options">Advanced arguments to the algorithm.</param>
576-
public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
577-
LinearSvmTrainer.Options options)
564+
public static RandomTrainer Random(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog)
578565
{
579566
Contracts.CheckValue(catalog, nameof(catalog));
580-
Contracts.CheckValue(options, nameof(options));
567+
return new RandomTrainer(CatalogUtils.GetEnvironment(catalog), new RandomTrainer.Options());
568+
}
581569

582-
return new LinearSvmTrainer(CatalogUtils.GetEnvironment(catalog), options);
570+
/// <summary>
571+
/// Predict a target using a binary classification model trained with <see cref="PriorTrainer"/> trainer.
572+
/// </summary>
573+
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
574+
/// <param name="labelColumn">The name of the label column. </param>
575+
/// <param name="weightsColumn">The optional name of the weights column.</param>
576+
public static PriorTrainer Prior(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
577+
string labelColumn = DefaultColumnNames.Label,
578+
string weightsColumn = null)
579+
{
580+
Contracts.CheckValue(catalog, nameof(catalog));
581+
return new PriorTrainer(CatalogUtils.GetEnvironment(catalog), labelColumn, weightsColumn);
583582
}
584583
}
585584
}

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,23 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
8585
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, tensorFlowModel);
8686

8787
/// <summary>
88-
/// Score or Retrain a tensorflow model (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) setting.
89-
/// The model is specified in the <see cref="TensorFlowTransformer.Options.ModelLocation"/>.
88+
/// Score or Retrain a tensorflow model (based on setting of the <see cref="TensorFlowEstimator.Options.ReTrain"/>) setting.
89+
/// The model is specified in the <see cref="TensorFlowEstimator.Options.ModelLocation"/>.
9090
/// </summary>
9191
/// <param name="catalog">The transform's catalog.</param>
92-
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
92+
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
9393
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
94-
TensorFlowTransformer.Options options)
94+
TensorFlowEstimator.Options options)
9595
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);
9696

9797
/// <summary>
98-
/// Scores or retrains (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
98+
/// Scores or retrains (based on setting of the <see cref="TensorFlowEstimator.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
9999
/// </summary>
100100
/// <param name="catalog">The transform's catalog.</param>
101-
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
101+
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
102102
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
103103
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
104-
TensorFlowTransformer.Options options,
104+
TensorFlowEstimator.Options options,
105105
TensorFlowModelInfo tensorFlowModel)
106106
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
107107
}

0 commit comments

Comments
 (0)