Skip to content

More pigstensions #1084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 28, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public Arguments()
BasePredictors = new[]
{
ComponentFactoryUtils.CreateFromFunction(
env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features))
};
}
}
Expand Down
28 changes: 17 additions & 11 deletions src/Microsoft.ML.KMeansClustering/KMeansStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.KMeans;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.StaticPipe.Runtime;
using System;

namespace Microsoft.ML.Trainers
namespace Microsoft.ML.StaticPipe
{
/// <summary>
/// The trainer context extensions for the <see cref="KMeansPlusPlusTrainer"/>.
Expand All @@ -35,16 +35,22 @@ public static (Vector<float> score, Key<uint> predictedLabel) KMeans(this Cluste
Action<KMeansPlusPlusTrainer.Arguments> advancedSettings = null,
Action<KMeansPredictor> onFit = null)
{
var rec = new TrainerEstimatorReconciler.Clustering(
(env, featuresName, weightsName) =>
{
var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);
Contracts.CheckValue(features, nameof(features));
Contracts.CheckValueOrNull(weights);
Contracts.CheckParam(clustersCount > 1, nameof(clustersCount), "If provided, must be greater than 1.");
Contracts.CheckValueOrNull(onFit);
Contracts.CheckValueOrNull(advancedSettings);

if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
return trainer;
}, features, weights);
var rec = new TrainerEstimatorReconciler.Clustering(
(env, featuresName, weightsName) =>
{
var trainer = new KMeansPlusPlusTrainer(env, featuresName, clustersCount, weightsName, advancedSettings);

if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
else
return trainer;
}, features, weights);

return rec.Output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ public void Add(Double summand)
public sealed class LinearClassificationTrainer : SdcaTrainerBase<BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
{
public const string LoadNameValue = "SDCA";
public const string UserNameValue = "Fast Linear (SA-SDCA)";
internal const string UserNameValue = "Fast Linear (SA-SDCA)";

public sealed class Arguments : ArgumentsBase
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,29 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
[TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")]
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
public float L2Weight = 1;
public float L2Weight = Defaults.L2Weight;

[Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1", SortOrder = 50)]
[TGUI(Label = "L1 Weight", Description = "Weight of L1 regularizer term", SuggestedSweeps = "0,0.1,1")]
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
public float L1Weight = 1;
public float L1Weight = Defaults.L1Weight;

[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for optimization convergence. Lower = slower, more accurate",
ShortName = "ot", SortOrder = 50)]
[TGUI(Label = "Optimization Tolerance", Description = "Threshold for optimizer convergence", SuggestedSweeps = "1e-4,1e-7")]
[TlcModule.SweepableDiscreteParamAttribute(new object[] { 1e-4f, 1e-7f })]
public float OptTol = 1e-7f;
public float OptTol = Defaults.OptTol;

[Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Lower=faster, less accurate",
ShortName = "m", SortOrder = 50)]
[TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")]
[TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })]
public int MemorySize = 20;
public int MemorySize = Defaults.MemorySize;

[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter")]
[TGUI(Label = "Max Number of Iterations")]
[TlcModule.SweepableLongParamAttribute("MaxIterations", 1, int.MaxValue)]
public int MaxIterations = int.MaxValue;
public int MaxIterations = Defaults.MaxIterations;

[Argument(ArgumentType.AtMostOnce, HelpText = "Run SGD to initialize LR weights, converging to this tolerance",
ShortName = "sgd")]
Expand Down Expand Up @@ -90,7 +90,17 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
public bool DenseOptimizer = false;

[Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)]
public bool EnforceNonNegativity = false;
public bool EnforceNonNegativity = Defaults.EnforceNonNegativity;

internal static class Defaults
{
internal const float L2Weight = 1;
internal const float L1Weight = 1;
internal const float OptTol = 1e-7f;
internal const int MemorySize = 20;
internal const int MaxIterations = int.MaxValue;
internal const bool EnforceNonNegativity = false;
}
}

private const string RegisterName = nameof(LbfgsTrainerBase<TArgs, TTransformer, TModel>);
Expand Down Expand Up @@ -142,40 +152,56 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
public override TrainerInfo Info => _info;

internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
string weightColumn = null, Action<TArgs> advancedSettings = null)
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn)
string weightColumn, Action<TArgs> advancedSettings, float l1Weight,
float l2Weight,
float optimizationTolerance,
int memorySize,
bool enforceNoNegativity)
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn,
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
}

internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn,
float? l1Weight = null,
float? l2Weight = null,
float? optimizationTolerance = null,
int? memorySize = null,
bool? enforceNoNegativity = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Args = args;

Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

L2Weight = Args.L2Weight;
L1Weight = Args.L1Weight;
OptTol = Args.OptTol;
MemorySize = Args.MemorySize;
Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided.");
Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided");
Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided.");
Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided.");

// Review: Warn about the overriding behavior
L2Weight = l2Weight ?? Args.L2Weight;
L1Weight = l1Weight ?? Args.L1Weight;
OptTol = optimizationTolerance ?? Args.OptTol;
MemorySize = memorySize ?? Args.MemorySize;
MaxIterations = Args.MaxIterations;
SgdInitializationTolerance = Args.SgdInitializationTolerance;
Quiet = Args.Quiet;
InitWtsDiameter = Args.InitWtsDiameter;
UseThreads = Args.UseThreads;
NumThreads = Args.NumThreads;
DenseOptimizer = Args.DenseOptimizer;
EnforceNonNegativity = Args.EnforceNonNegativity;
EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity;

if (EnforceNonNegativity && ShowTrainingStats)
{
Expand Down
Loading