diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
index 8a24b86aef..25772bb82d 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
@@ -28,7 +28,7 @@ public static void Example()
var mlContext = new MLContext();
// Create a text loader.
- var reader = mlContext.Data.CreateTextLoader(new TextLoader.Arguments()
+ var reader = mlContext.Data.CreateTextLoader(new TextLoader.Options()
{
Separators = new[] { '\t' },
HasHeader = true,
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs
index 87152b211b..b168575290 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs
@@ -23,7 +23,7 @@ public static void Example()
// hotdog.jpg hotdog
// tomato.jpg tomato
- var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments()
+ var data = mlContext.Data.CreateTextLoader(new TextLoader.Options()
{
Columns = new[]
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs
index 4f88cdf531..eb7e164004 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs
@@ -24,7 +24,7 @@ public static void Example()
// hotdog.jpg hotdog
// tomato.jpg tomato
- var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments()
+ var data = mlContext.Data.CreateTextLoader(new TextLoader.Options()
{
Columns = new[]
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs
index 5979c4bb3b..541f564283 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs
@@ -23,7 +23,7 @@ public static void Example()
// hotdog.jpg hotdog
// tomato.jpg tomato
- var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments()
+ var data = mlContext.Data.CreateTextLoader(new TextLoader.Options()
{
Columns = new[]
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs
index 4e59c20a83..03ada5304e 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs
@@ -23,7 +23,7 @@ public static void Example()
// hotdog.jpg hotdog
// tomato.jpg tomato
- var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments()
+ var data = mlContext.Data.CreateTextLoader(new TextLoader.Options()
{
Columns = new[]
{
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
index aaf8d02e50..1a6aacfe33 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
@@ -31,7 +31,7 @@ public static void Example()
// 14. Column: native-country (text/categorical)
// 15. Column: Column [Label]: IsOver50K (boolean)
- var reader = ml.Data.CreateTextLoader(new TextLoader.Arguments
+ var reader = ml.Data.CreateTextLoader(new TextLoader.Options
{
Separators = new[] { ',' },
HasHeader = true,
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs
index ee568bff92..830b5981cc 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs
@@ -24,7 +24,7 @@ public static void Example()
// Define the trainer options.
var options = new AveragedPerceptronTrainer.Options()
{
- LossFunction = new SmoothedHingeLoss.Arguments(),
+ LossFunction = new SmoothedHingeLoss.Options(),
LearningRate = 0.1f,
DoLazyUpdates = false,
RecencyGain = 0.1f,
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs
index c660b603c4..3a8a17952b 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs
@@ -22,7 +22,7 @@ public static void Example()
// The data is tab separated with all numeric columns.
// The first column being the label and rest are numeric features
// Here only seven numeric columns are used as features
- var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Arguments
+ var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Options
{
Separators = new[] { '\t' },
HasHeader = true,
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs
index 1d72e487cd..519a9ef683 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs
@@ -23,7 +23,7 @@ public static void Example()
// The data is tab separated with all numeric columns.
// The first column being the label and rest are numeric features
// Here only seven numeric columns are used as features
- var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Arguments
+ var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Options
{
Separators = new[] { '\t' },
HasHeader = true,
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index 8997cfd29b..8e8a390f1e 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -364,11 +364,11 @@ protected IDataLoader CreateRawLoader(
var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
- return isText ? TextLoader.Create(Host, new TextLoader.Arguments(), fileSource) :
+ return isText ? TextLoader.Create(Host, new TextLoader.Options(), fileSource) :
isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
- TextLoader.Create(Host, new TextLoader.Arguments(), fileSource);
+ TextLoader.Create(Host, new TextLoader.Options(), fileSource);
}
else
{
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index 17ed87c05d..b1c0a40d88 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -15,7 +15,7 @@
using Microsoft.ML.Model;
using Float = System.Single;
-[assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), typeof(TextLoader.Arguments), typeof(SignatureDataLoader),
+[assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), typeof(TextLoader.Options), typeof(SignatureDataLoader),
"Text Loader", "TextLoader", "Text", DocName = "loader/TextLoader.md")]
[assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader),
@@ -379,7 +379,7 @@ public bool IsValid()
}
}
- public sealed class Arguments : ArgumentsCore
+ public sealed class Options : ArgumentsCore
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Use separate parsing threads?", ShortName = "threads", Hide = true)]
public bool UseThreads = true;
@@ -936,7 +936,7 @@ private static VersionInfo GetVersionInfo()
/// bumping the version number.
///
[Flags]
- private enum Options : uint
+ private enum OptionFlags : uint
{
TrimWhitespace = 0x01,
HasHeader = 0x02,
@@ -950,7 +950,7 @@ private enum Options : uint
private const int SrcLim = int.MaxValue;
private readonly bool _useThreads;
- private readonly Options _flags;
+ private readonly OptionFlags _flags;
private readonly long _maxRows;
// Input size is zero for unknown - determined by the data (including sparse rows).
private readonly int _inputSize;
@@ -961,7 +961,7 @@ private enum Options : uint
private bool HasHeader
{
- get { return (_flags & Options.HasHeader) != 0; }
+ get { return (_flags & OptionFlags.HasHeader) != 0; }
}
private readonly IHost _host;
@@ -980,10 +980,10 @@ public TextLoader(IHostEnvironment env, Column[] columns, bool hasHeader = false
{
}
- private static Arguments MakeArgs(Column[] columns, bool hasHeader, char[] separatorChars)
+ private static Options MakeArgs(Column[] columns, bool hasHeader, char[] separatorChars)
{
Contracts.AssertValue(separatorChars);
- var result = new Arguments { Columns = columns, HasHeader = hasHeader, Separators = separatorChars};
+ var result = new Options { Columns = columns, HasHeader = hasHeader, Separators = separatorChars};
return result;
}
@@ -991,27 +991,27 @@ private static Arguments MakeArgs(Column[] columns, bool hasHeader, char[] separ
/// Loads a text file into an . Supports basic mapping from input columns to IDataView columns.
///
/// The environment to use.
- /// Defines the settings of the load operation.
+ /// Defines the settings of the load operation.
/// Allows to expose items that can be used for reading.
- public TextLoader(IHostEnvironment env, Arguments args = null, IMultiStreamSource dataSample = null)
+ public TextLoader(IHostEnvironment env, Options options = null, IMultiStreamSource dataSample = null)
{
- args = args ?? new Arguments();
+ options = options ?? new Options();
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
- _host.CheckValue(args, nameof(args));
+ _host.CheckValue(options, nameof(options));
_host.CheckValueOrNull(dataSample);
if (dataSample == null)
dataSample = new MultiFileSource(null);
IMultiStreamSource headerFile = null;
- if (!string.IsNullOrWhiteSpace(args.HeaderFile))
- headerFile = new MultiFileSource(args.HeaderFile);
+ if (!string.IsNullOrWhiteSpace(options.HeaderFile))
+ headerFile = new MultiFileSource(options.HeaderFile);
- var cols = args.Columns;
+ var cols = options.Columns;
bool error;
- if (Utils.Size(cols) == 0 && !TryParseSchema(_host, headerFile ?? dataSample, ref args, out cols, out error))
+ if (Utils.Size(cols) == 0 && !TryParseSchema(_host, headerFile ?? dataSample, ref options, out cols, out error))
{
if (error)
throw _host.Except("TextLoader options embedded in the file are invalid");
@@ -1026,43 +1026,43 @@ public TextLoader(IHostEnvironment env, Arguments args = null, IMultiStreamSourc
}
_host.Assert(Utils.Size(cols) > 0);
- _useThreads = args.UseThreads;
+ _useThreads = options.UseThreads;
- if (args.TrimWhitespace)
- _flags |= Options.TrimWhitespace;
- if (headerFile == null && args.HasHeader)
- _flags |= Options.HasHeader;
- if (args.AllowQuoting)
- _flags |= Options.AllowQuoting;
- if (args.AllowSparse)
- _flags |= Options.AllowSparse;
+ if (options.TrimWhitespace)
+ _flags |= OptionFlags.TrimWhitespace;
+ if (headerFile == null && options.HasHeader)
+ _flags |= OptionFlags.HasHeader;
+ if (options.AllowQuoting)
+ _flags |= OptionFlags.AllowQuoting;
+ if (options.AllowSparse)
+ _flags |= OptionFlags.AllowSparse;
// REVIEW: This should be persisted (if it should be maintained).
- _maxRows = args.MaxRows ?? long.MaxValue;
- _host.CheckUserArg(_maxRows >= 0, nameof(args.MaxRows));
+ _maxRows = options.MaxRows ?? long.MaxValue;
+ _host.CheckUserArg(_maxRows >= 0, nameof(options.MaxRows));
// Note that _maxDim == 0 means sparsity is illegal.
- _inputSize = args.InputSize ?? 0;
+ _inputSize = options.InputSize ?? 0;
_host.Check(_inputSize >= 0, "inputSize");
if (_inputSize >= SrcLim)
_inputSize = SrcLim - 1;
- _host.CheckNonEmpty(args.Separator, nameof(args.Separator), "Must specify a separator");
+ _host.CheckNonEmpty(options.Separator, nameof(options.Separator), "Must specify a separator");
//Default arg.Separator is tab and default args.Separators is also a '\t'.
//At a time only one default can be different and whichever is different that will
//be used.
- if (args.Separators.Length > 1 || args.Separators[0] != '\t')
+ if (options.Separators.Length > 1 || options.Separators[0] != '\t')
{
var separators = new HashSet();
- foreach (char c in args.Separators)
+ foreach (char c in options.Separators)
separators.Add(NormalizeSeparator(c.ToString()));
_separators = separators.ToArray();
}
else
{
- string sep = args.Separator.ToLowerInvariant();
+ string sep = options.Separator.ToLowerInvariant();
if (sep == ",")
_separators = new char[] { ',' };
else
@@ -1103,7 +1103,7 @@ private char NormalizeSeparator(string sep)
return ',';
case "colon":
case ":":
- _host.CheckUserArg((_flags & Options.AllowSparse) == 0, nameof(Arguments.Separator),
+ _host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator),
"When the separator is colon, turn off allowSparse");
return ':';
case "semicolon":
@@ -1115,7 +1115,7 @@ private char NormalizeSeparator(string sep)
default:
char ch = sep[0];
if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"')
- throw _host.ExceptUserArg(nameof(Arguments.Separator), "Illegal separator: '{0}'", sep);
+ throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep);
return sep[0];
}
}
@@ -1134,7 +1134,7 @@ private sealed class LoaderHolder
// If so, update args and set cols to the combined set of columns.
// If not, set error to true if there was an error condition.
private static bool TryParseSchema(IHost host, IMultiStreamSource files,
- ref Arguments args, out Column[] cols, out bool error)
+ ref Options options, out Column[] cols, out bool error)
{
host.AssertValue(host);
host.AssertValue(files);
@@ -1144,7 +1144,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
// Verify that the current schema-defining arguments are default.
// Get settings just for core arguments, not everything.
- string tmp = CmdParser.GetSettings(host, args, new ArgumentsCore());
+ string tmp = CmdParser.GetSettings(host, options, new ArgumentsCore());
// Try to get the schema information from the file.
string str = Cursor.GetEmbeddedArgs(files);
@@ -1176,12 +1176,12 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
// Make sure the loader binds to us.
var info = host.ComponentCatalog.GetLoadableClassInfo(loader.Name);
- if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Arguments))
+ if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Options))
goto LDone;
- var argsNew = new Arguments();
+ var argsNew = new Options();
// Copy the non-core arguments to the new args (we already know that all the core arguments are default).
- var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, args, new Arguments()), argsNew);
+ var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, options, new Options()), argsNew);
ch.Assert(parsed);
// Copy the core arguments to the new args.
if (!CmdParser.ParseArguments(host, loader.GetSettingsString(), argsNew, typeof(ArgumentsCore), msg => ch.Error(msg)))
@@ -1192,7 +1192,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
goto LDone;
error = false;
- args = argsNew;
+ options = argsNew;
LDone:
return !error;
@@ -1202,16 +1202,16 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
///
/// Checks whether the source contains the valid TextLoader.Arguments depiction.
///
- public static bool FileContainsValidSchema(IHostEnvironment env, IMultiStreamSource files, out Arguments args)
+ public static bool FileContainsValidSchema(IHostEnvironment env, IMultiStreamSource files, out Options options)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
h.CheckValue(files, nameof(files));
- args = new Arguments();
+ options = new Options();
Column[] cols;
bool error;
- bool found = TryParseSchema(h, files, ref args, out cols, out error);
- return found && !error && args.IsValid();
+ bool found = TryParseSchema(h, files, ref options, out cols, out error);
+ return found && !error && options.IsValid();
}
private TextLoader(IHost host, ModelLoadContext ctx)
@@ -1236,8 +1236,8 @@ private TextLoader(IHost host, ModelLoadContext ctx)
host.CheckDecode(cbFloat == sizeof(Float));
_maxRows = ctx.Reader.ReadInt64();
host.CheckDecode(_maxRows > 0);
- _flags = (Options)ctx.Reader.ReadUInt32();
- host.CheckDecode((_flags & ~Options.All) == 0);
+ _flags = (OptionFlags)ctx.Reader.ReadUInt32();
+ host.CheckDecode((_flags & ~OptionFlags.All) == 0);
_inputSize = ctx.Reader.ReadInt32();
host.CheckDecode(0 <= _inputSize && _inputSize < SrcLim);
@@ -1253,7 +1253,7 @@ private TextLoader(IHost host, ModelLoadContext ctx)
}
if (_separators.Contains(':'))
- host.CheckDecode((_flags & Options.AllowSparse) == 0);
+ host.CheckDecode((_flags & OptionFlags.AllowSparse) == 0);
_bindings = new Bindings(ctx, this);
_parser = new Parser(this);
@@ -1273,14 +1273,14 @@ internal static TextLoader Create(IHostEnvironment env, ModelLoadContext ctx)
// These are legacy constructors needed for ComponentCatalog.
internal static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
=> (IDataLoader)Create(env, ctx).Read(files);
- internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStreamSource files)
- => (IDataLoader)new TextLoader(env, args, files).Read(files);
+ internal static IDataLoader Create(IHostEnvironment env, Options options, IMultiStreamSource files)
+ => (IDataLoader)new TextLoader(env, options, files).Read(files);
///
/// Convenience method to create a and use it to read a specified file.
///
- internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource)
- => new TextLoader(env, args, fileSource).Read(fileSource);
+ internal static IDataView ReadFile(IHostEnvironment env, Options options, IMultiStreamSource fileSource)
+ => new TextLoader(env, options, fileSource).Read(fileSource);
void ICanSaveModel.Save(ModelSaveContext ctx)
{
@@ -1298,7 +1298,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
// bindings
ctx.Writer.Write(sizeof(Float));
ctx.Writer.Write(_maxRows);
- _host.Assert((_flags & ~Options.All) == 0);
+ _host.Assert((_flags & ~OptionFlags.All) == 0);
ctx.Writer.Write((uint)_flags);
_host.Assert(0 <= _inputSize && _inputSize < SrcLim);
ctx.Writer.Write(_inputSize);
@@ -1367,7 +1367,7 @@ internal static TextLoader CreateTextReader(IHostEnvironment host,
columns.Add(column);
}
- Arguments args = new Arguments
+ Options options = new Options
{
HasHeader = hasHeader,
Separators = new[] { separator },
@@ -1377,7 +1377,7 @@ internal static TextLoader CreateTextReader(IHostEnvironment host,
Columns = columns.ToArray()
};
- return new TextLoader(host, args);
+ return new TextLoader(host, options);
}
private sealed class BoundLoader : IDataLoader
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
index 7c7305c723..df6252c6bc 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
@@ -637,7 +637,7 @@ public void Clear()
}
private readonly char[] _separators;
- private readonly Options _flags;
+ private readonly OptionFlags _flags;
private readonly int _inputSize;
private readonly ColInfo[] _infos;
@@ -705,7 +705,7 @@ public static void GetInputSize(TextLoader parent, List> li
{
foreach (var line in lines)
{
- var text = (parent._flags & Options.TrimWhitespace) != 0 ? ReadOnlyMemoryUtils.TrimEndWhiteSpace(line) : line;
+ var text = (parent._flags & OptionFlags.TrimWhitespace) != 0 ? ReadOnlyMemoryUtils.TrimEndWhiteSpace(line) : line;
if (text.IsEmpty)
continue;
@@ -828,7 +828,7 @@ public void ParseRow(RowSet rows, int irow, Helper helper, bool[] active, string
var impl = (HelperImpl)helper;
var lineSpan = text.AsMemory();
var span = lineSpan.Span;
- if ((_flags & Options.TrimWhitespace) != 0)
+ if ((_flags & OptionFlags.TrimWhitespace) != 0)
lineSpan = TrimEndWhiteSpace(lineSpan, span);
try
{
@@ -883,7 +883,7 @@ private sealed class HelperImpl : Helper
public readonly FieldSet Fields;
- public HelperImpl(ParseStats stats, Options flags, char[] seps, int inputSize, int srcNeeded)
+ public HelperImpl(ParseStats stats, OptionFlags flags, char[] seps, int inputSize, int srcNeeded)
{
Contracts.AssertValue(stats);
// inputSize == 0 means unknown.
@@ -899,8 +899,8 @@ public HelperImpl(ParseStats stats, Options flags, char[] seps, int inputSize, i
_sepContainsSpace = IsSep(' ');
_inputSize = inputSize;
_srcNeeded = srcNeeded;
- _quoting = (flags & Options.AllowQuoting) != 0;
- _sparse = (flags & Options.AllowSparse) != 0;
+ _quoting = (flags & OptionFlags.AllowQuoting) != 0;
+ _sparse = (flags & OptionFlags.AllowSparse) != 0;
_sb = new StringBuilder();
_blank = ReadOnlyMemory.Empty;
Fields = new FieldSet();
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
index f081220f41..7b5e684b46 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
@@ -30,12 +30,12 @@ public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog,
/// Create a text loader .
///
/// The catalog.
- /// Defines the settings of the load operation.
+ /// Defines the settings of the load operation.
/// The optional location of a data sample. The sample can be used to infer column names and number of slots in each column.
public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog,
- TextLoader.Arguments args,
+ TextLoader.Options options,
IMultiStreamSource dataSample = null)
- => new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample);
+ => new TextLoader(CatalogUtils.GetEnvironment(catalog), options, dataSample);
///
/// Create a text loader by inferencing the dataset schema from a data model type.
@@ -123,15 +123,15 @@ public static IDataView ReadFromTextFile(this DataOperationsCatalog cata
///
/// The catalog.
/// Specifies a file from which to read.
- /// Defines the settings of the load operation.
- public static IDataView ReadFromTextFile(this DataOperationsCatalog catalog, string path, TextLoader.Arguments args = null)
+ /// Defines the settings of the load operation.
+ public static IDataView ReadFromTextFile(this DataOperationsCatalog catalog, string path, TextLoader.Options options = null)
{
Contracts.CheckNonEmpty(path, nameof(path));
var env = catalog.GetEnvironment();
var source = new MultiFileSource(path);
- return new TextLoader(env, args, source).Read(source);
+ return new TextLoader(env, options, source).Read(source);
}
///
diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
index 26a58b5b1a..6a53a5d2ce 100644
--- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
+++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
@@ -9,11 +9,11 @@
using Microsoft.ML.Internal.Internallearn;
using Float = System.Single;
-[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")]
-[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")]
-[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")]
-[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")]
-[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")]
+[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")]
+[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")]
+[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")]
+[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")]
+[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")]
[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))]
[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))]
@@ -44,14 +44,14 @@ public interface IEarlyStoppingCriterionFactory : IComponentFactory : IEarlyStoppingCriterion
- where TArguments : EarlyStoppingCriterion.ArgumentsBase
+ public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion
+ where TOptions : EarlyStoppingCriterion.OptionsBase
{
- public abstract class ArgumentsBase { }
+ public abstract class OptionsBase { }
private Float _bestScore;
- protected readonly TArguments Args;
+ protected readonly TOptions EarlyStoppingCriterionOptions;
protected readonly bool LowerIsBetter;
protected Float BestScore {
get { return _bestScore; }
@@ -62,9 +62,9 @@ protected Float BestScore {
}
}
- internal EarlyStoppingCriterion(TArguments args, bool lowerIsBetter)
+ internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter)
{
- Args = args;
+ EarlyStoppingCriterionOptions = options;
LowerIsBetter = lowerIsBetter;
_bestScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity;
}
@@ -86,10 +86,10 @@ protected bool CheckBestScore(Float score)
}
}
- public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion
+ public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion
{
[TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")]
- public class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory
+ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")]
[TlcModule.Range(Min = 0.0f)]
@@ -101,10 +101,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI
}
}
- public TolerantEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ public TolerantEarlyStoppingCriterion(Options args, bool lowerIsBetter)
: base(args, lowerIsBetter)
{
- Contracts.CheckUserArg(Args.Threshold >= 0, nameof(args.Threshold), "Must be non-negative.");
+ Contracts.CheckUserArg(args.Threshold >= 0, nameof(args.Threshold), "Must be non-negative.");
}
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
@@ -114,9 +114,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
isBestCandidate = CheckBestScore(validationScore);
if (LowerIsBetter)
- return (validationScore - BestScore > Args.Threshold);
+ return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold);
else
- return (BestScore - validationScore > Args.Threshold);
+ return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold);
}
}
@@ -124,9 +124,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
// Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
- public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
+ public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
{
- public class Arguments : ArgumentsBase
+ public class Options : OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
@@ -139,13 +139,13 @@ public class Arguments : ArgumentsBase
protected Queue PastScores;
- private protected MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter)
: base(args, lowerIsBetter)
{
- Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
- Contracts.CheckUserArg(Args.WindowSize > 0, nameof(args.WindowSize), "Must be positive.");
+ Contracts.CheckUserArg(0 <= args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
+ Contracts.CheckUserArg(args.WindowSize > 0, nameof(args.WindowSize), "Must be positive.");
- PastScores = new Queue(Args.WindowSize);
+ PastScores = new Queue(EarlyStoppingCriterionOptions.WindowSize);
}
///
@@ -203,11 +203,11 @@ protected bool CheckRecentScores(Float score, int windowSize, out Float recentBe
///
/// Loss of Generality (GL).
///
- public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
+ public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
{
[TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL",
Desc = "Stop in case of loss of generality.")]
- public class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory
+ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
@@ -219,10 +219,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI
}
}
- public GLEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
- : base(args, lowerIsBetter)
+ public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter)
+ : base(options, lowerIsBetter)
{
- Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
+ Contracts.CheckUserArg(0 <= options.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
}
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
@@ -232,9 +232,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
isBestCandidate = CheckBestScore(validationScore);
if (LowerIsBetter)
- return (validationScore > (1 + Args.Threshold) * BestScore);
+ return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore);
else
- return (validationScore < (1 - Args.Threshold) * BestScore);
+ return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore);
}
}
@@ -245,7 +245,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
{
[TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")]
- public new sealed class Arguments : MovingWindowEarlyStoppingCriterion.Arguments, IEarlyStoppingCriterionFactory
+ public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
{
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
@@ -253,7 +253,7 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI
}
}
- public LPEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ public LPEarlyStoppingCriterion(Options args, bool lowerIsBetter)
: base(args, lowerIsBetter) { }
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
@@ -265,12 +265,12 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
Float recentBest;
Float recentAverage;
- if (CheckRecentScores(trainingScore, Args.WindowSize, out recentBest, out recentAverage))
+ if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
{
if (LowerIsBetter)
- return (recentAverage <= (1 + Args.Threshold) * recentBest);
+ return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest);
else
- return (recentAverage >= (1 - Args.Threshold) * recentBest);
+ return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest);
}
return false;
@@ -283,7 +283,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
{
[TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")]
- public new sealed class Arguments : MovingWindowEarlyStoppingCriterion.Arguments, IEarlyStoppingCriterionFactory
+ public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
{
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
@@ -291,7 +291,7 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI
}
}
- public PQEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ public PQEarlyStoppingCriterion(Options args, bool lowerIsBetter)
: base(args, lowerIsBetter) { }
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
@@ -303,12 +303,12 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
Float recentBest;
Float recentAverage;
- if (CheckRecentScores(trainingScore, Args.WindowSize, out recentBest, out recentAverage))
+ if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
{
if (LowerIsBetter)
- return (validationScore * recentBest >= (1 + Args.Threshold) * BestScore * recentAverage);
+ return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
else
- return (validationScore * recentBest <= (1 - Args.Threshold) * BestScore * recentAverage);
+ return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
}
return false;
@@ -318,11 +318,11 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
///
/// Consecutive Loss in Generality (UP).
///
- public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion<UPEarlyStoppingCriterion.Arguments>
+ public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion<UPEarlyStoppingCriterion.Options>
{
[TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP",
Desc = "Stops in case of consecutive loss in generality.")]
- public sealed class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory
+ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")]
[TlcModule.Range(Inf = 0)]
@@ -337,10 +337,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI
private int _count;
private Float _prevScore;
- public UPEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
- : base(args, lowerIsBetter)
+ public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
+ : base(options, lowerIsBetter)
{
- Contracts.CheckUserArg(Args.WindowSize > 0, nameof(args.WindowSize), "Must be positive");
+ Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive");
_prevScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity;
}
@@ -354,7 +354,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out
_count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0;
_prevScore = validationScore;
- return (_count >= Args.WindowSize);
+ return (_count >= EarlyStoppingCriterionOptions.WindowSize);
}
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
index 2321fcaaf5..864cceb128 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
@@ -88,14 +88,14 @@ public static LabelIndicatorTransform Create(IHostEnvironment env,
}
public static LabelIndicatorTransform Create(IHostEnvironment env,
- Options args, IDataView input)
+ Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
IHost h = env.Register(LoaderSignature);
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
return h.Apply("Loading Model",
- ch => new LabelIndicatorTransform(h, args, input));
+ ch => new LabelIndicatorTransform(h, options, input));
}
private protected override void SaveModel(ModelSaveContext ctx)
@@ -131,16 +131,16 @@ public LabelIndicatorTransform(IHostEnvironment env,
{
}
- public LabelIndicatorTransform(IHostEnvironment env, Options args, IDataView input)
- : base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Columns,
+ public LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
+ : base(env, LoadName, Contracts.CheckRef(options, nameof(options)).Columns,
input, TestIsMulticlassLabel)
{
Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Columns));
+ Host.Assert(Infos.Length == Utils.Size(options.Columns));
_classIndex = new int[Infos.Length];
for (int iinfo = 0; iinfo < Infos.Length; ++iinfo)
- _classIndex[iinfo] = args.Columns[iinfo].ClassIndex ?? args.ClassIndex;
+ _classIndex[iinfo] = options.Columns[iinfo].ClassIndex ?? options.ClassIndex;
Metadata.Seal();
}
diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
index 2431667c04..70b69e574b 100644
--- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
@@ -107,21 +107,21 @@ public RowShufflingTransformer(IHostEnvironment env,
///
/// Constructor corresponding to SignatureDataTransform.
///
- public RowShufflingTransformer(IHostEnvironment env, Options args, IDataView input)
+ public RowShufflingTransformer(IHostEnvironment env, Options options, IDataView input)
: base(env, RegistrationName, input)
{
- Host.CheckValue(args, nameof(args));
+ Host.CheckValue(options, nameof(options));
- Host.CheckUserArg(args.PoolRows > 0, nameof(args.PoolRows), "pool size must be positive");
- _poolRows = args.PoolRows;
- _poolOnly = args.PoolOnly;
- _forceShuffle = args.ForceShuffle;
- _forceShuffleSource = args.ForceShuffleSource ?? (!_poolOnly && _forceShuffle);
+ Host.CheckUserArg(options.PoolRows > 0, nameof(options.PoolRows), "pool size must be positive");
+ _poolRows = options.PoolRows;
+ _poolOnly = options.PoolOnly;
+ _forceShuffle = options.ForceShuffle;
+ _forceShuffleSource = options.ForceShuffleSource ?? (!_poolOnly && _forceShuffle);
Host.CheckUserArg(!(_poolOnly && _forceShuffleSource),
- nameof(args.ForceShuffleSource), "Cannot set both poolOnly and forceShuffleSource");
+ nameof(options.ForceShuffleSource), "Cannot set both poolOnly and forceShuffleSource");
if (_forceShuffle || _forceShuffleSource)
- _forceShuffleSeed = args.ForceShuffleSeed ?? Host.Rand.NextSigned();
+ _forceShuffleSeed = options.ForceShuffleSeed ?? Host.Rand.NextSigned();
_subsetInput = SelectCachableColumns(input, env);
}
diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
index 7e71763125..b79610e53e 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs
@@ -600,7 +600,7 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
{
keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0);
valueColumn = new TextLoader.Column(valueColumnName, DataKind.TXT, 1);
- var txtArgs = new TextLoader.Arguments()
+ var txtArgs = new TextLoader.Options()
{
Columns = new TextLoader.Column[]
{
@@ -627,7 +627,7 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
loader = TextLoader.Create(
env,
- new TextLoader.Arguments()
+ new TextLoader.Options()
{
Columns = new TextLoader.Column[]
{
diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs
index ed3ad17765..bf3eec4af7 100644
--- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs
+++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs
@@ -12,13 +12,13 @@
[assembly: LoadableClass(LogLoss.Summary, typeof(LogLoss), null, typeof(SignatureClassificationLoss),
"Log Loss", "LogLoss", "Logistic", "CrossEntropy")]
-[assembly: LoadableClass(HingeLoss.Summary, typeof(HingeLoss), typeof(HingeLoss.Arguments), typeof(SignatureClassificationLoss),
+[assembly: LoadableClass(HingeLoss.Summary, typeof(HingeLoss), typeof(HingeLoss.Options), typeof(SignatureClassificationLoss),
"Hinge Loss", "HingeLoss", "Hinge")]
-[assembly: LoadableClass(SmoothedHingeLoss.Summary, typeof(SmoothedHingeLoss), typeof(SmoothedHingeLoss.Arguments), typeof(SignatureClassificationLoss),
+[assembly: LoadableClass(SmoothedHingeLoss.Summary, typeof(SmoothedHingeLoss), typeof(SmoothedHingeLoss.Options), typeof(SignatureClassificationLoss),
"Smoothed Hinge Loss", "SmoothedHingeLoss", "SmoothedHinge")]
-[assembly: LoadableClass(ExpLoss.Summary, typeof(ExpLoss), typeof(ExpLoss.Arguments), typeof(SignatureClassificationLoss),
+[assembly: LoadableClass(ExpLoss.Summary, typeof(ExpLoss), typeof(ExpLoss.Options), typeof(SignatureClassificationLoss),
"Exponential Loss", "ExpLoss", "Exp")]
[assembly: LoadableClass(SquaredLoss.Summary, typeof(SquaredLoss), null, typeof(SignatureRegressionLoss),
@@ -27,16 +27,16 @@
[assembly: LoadableClass(PoissonLoss.Summary, typeof(PoissonLoss), null, typeof(SignatureRegressionLoss),
"Poisson Loss", "PoissonLoss", "Poisson")]
-[assembly: LoadableClass(TweedieLoss.Summary, typeof(TweedieLoss), typeof(TweedieLoss.Arguments), typeof(SignatureRegressionLoss),
+[assembly: LoadableClass(TweedieLoss.Summary, typeof(TweedieLoss), typeof(TweedieLoss.Options), typeof(SignatureRegressionLoss),
"Tweedie Loss", "TweedieLoss", "Tweedie", "Tw")]
-[assembly: EntryPointModule(typeof(ExpLoss.Arguments))]
+[assembly: EntryPointModule(typeof(ExpLoss.Options))]
[assembly: EntryPointModule(typeof(LogLossFactory))]
-[assembly: EntryPointModule(typeof(HingeLoss.Arguments))]
+[assembly: EntryPointModule(typeof(HingeLoss.Options))]
[assembly: EntryPointModule(typeof(PoissonLossFactory))]
-[assembly: EntryPointModule(typeof(SmoothedHingeLoss.Arguments))]
+[assembly: EntryPointModule(typeof(SmoothedHingeLoss.Options))]
[assembly: EntryPointModule(typeof(SquaredLossFactory))]
-[assembly: EntryPointModule(typeof(TweedieLoss.Arguments))]
+[assembly: EntryPointModule(typeof(TweedieLoss.Options))]
namespace Microsoft.ML
{
@@ -162,7 +162,7 @@ private static Double Log(Double x)
public sealed class HingeLoss : ISupportSdcaClassificationLoss
{
[TlcModule.Component(Name = "HingeLoss", FriendlyName = "Hinge loss", Alias = "Hinge", Desc = "Hinge loss.")]
- public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
+ public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")]
public Float Margin = Defaults.Margin;
@@ -176,9 +176,9 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC
private const Float Threshold = 0.5f;
private readonly Float _margin;
- internal HingeLoss(Arguments args)
+ internal HingeLoss(Options options)
{
- _margin = args.Margin;
+ _margin = options.Margin;
}
private static class Defaults
@@ -187,7 +187,7 @@ private static class Defaults
}
public HingeLoss(float margin = Defaults.Margin)
- : this(new Arguments() { Margin = margin })
+ : this(new Options() { Margin = margin })
{
}
@@ -234,7 +234,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss
{
[TlcModule.Component(Name = "SmoothedHingeLoss", FriendlyName = "Smoothed Hinge Loss", Alias = "SmoothedHinge",
Desc = "Smoothed Hinge loss.")]
- public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
+ public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")]
public Float SmoothingConst = Defaults.SmoothingConst;
@@ -268,8 +268,8 @@ public SmoothedHingeLoss(float smoothingConstant = Defaults.SmoothingConst)
_doubleSmoothConst = _smoothConst * 2;
}
- private SmoothedHingeLoss(IHostEnvironment env, Arguments args)
- : this(args.SmoothingConst)
+ private SmoothedHingeLoss(IHostEnvironment env, Options options)
+ : this(options.SmoothingConst)
{
}
@@ -333,7 +333,7 @@ public Double DualLoss(Float label, Double dual)
public sealed class ExpLoss : IClassificationLoss
{
[TlcModule.Component(Name = "ExpLoss", FriendlyName = "Exponential Loss", Desc = "Exponential loss.")]
- public sealed class Arguments : ISupportClassificationLossFactory
+ public sealed class Options : ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Beta (dilation)", ShortName = "beta")]
public Float Beta = 1;
@@ -345,9 +345,9 @@ public sealed class Arguments : ISupportClassificationLossFactory
private readonly Float _beta;
- public ExpLoss(Arguments args)
+ public ExpLoss(Options options)
{
- _beta = args.Beta;
+ _beta = options.Beta;
}
public Double Loss(Float output, Float label)
@@ -438,7 +438,7 @@ public Float Derivative(Float output, Float label)
public sealed class TweedieLoss : IRegressionLoss
{
[TlcModule.Component(Name = "TweedieLoss", FriendlyName = "Tweedie Loss", Alias = "tweedie", Desc = "Tweedie loss.")]
- public sealed class Arguments : ISupportRegressionLossFactory
+ public sealed class Options : ISupportRegressionLossFactory
{
[Argument(ArgumentType.LastOccurenceWins, HelpText =
"Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " +
@@ -454,10 +454,10 @@ public sealed class Arguments : ISupportRegressionLossFactory
private readonly Double _index1; // 1 minus the index parameter.
private readonly Double _index2; // 2 minus the index parameter.
- public TweedieLoss(Arguments args)
+ public TweedieLoss(Options options)
{
- Contracts.CheckUserArg(1 <= args.Index && args.Index <= 2, nameof(args.Index), "Must be in the range [1, 2]");
- _index = args.Index;
+ Contracts.CheckUserArg(1 <= options.Index && options.Index <= 2, nameof(options.Index), "Must be in the range [1, 2]");
+ _index = options.Index;
_index1 = 1 - _index;
_index2 = 2 - _index;
}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
index dd9eaf41b6..67bd565193 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
@@ -258,10 +258,10 @@ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipel
switch (input.ModelCombiner)
{
case ClassifierCombiner.Median:
- combiner = new MultiMedian(host, new MultiMedian.Arguments() { Normalize = true });
+ combiner = new MultiMedian(host, new MultiMedian.Options() { Normalize = true });
break;
case ClassifierCombiner.Average:
- combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true });
+ combiner = new MultiAverage(host, new MultiAverage.Options() { Normalize = true });
break;
case ClassifierCombiner.Vote:
combiner = new MultiVoting(host);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
index 044ccfb44e..053c8d03ba 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs
@@ -12,7 +12,7 @@ namespace Microsoft.ML.Trainers.Ensemble
{
public abstract class BaseMultiAverager : BaseMultiCombiner
{
- internal BaseMultiAverager(IHostEnvironment env, string name, ArgumentsBase args)
+ internal BaseMultiAverager(IHostEnvironment env, string name, OptionsBase args)
: base(env, name, args)
{
}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
index d0ee4583b3..877f08d78a 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs
@@ -15,7 +15,7 @@ public abstract class BaseMultiCombiner : IMultiClassOutputCombiner, ICanSaveMod
{
protected readonly IHost Host;
- public abstract class ArgumentsBase
+ public abstract class OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to normalize the output of base models before combining them",
ShortName = "norm", SortOrder = 50)]
@@ -24,7 +24,7 @@ public abstract class ArgumentsBase
protected readonly bool Normalize;
- internal BaseMultiCombiner(IHostEnvironment env, string name, ArgumentsBase args)
+ internal BaseMultiCombiner(IHostEnvironment env, string name, OptionsBase args)
{
Contracts.AssertValue(env);
env.AssertNonWhiteSpace(name);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
index 3cf424b50f..0bfc938fdf 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
@@ -9,7 +9,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.Ensemble;
-[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner),
+[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Options), typeof(SignatureCombiner),
Average.UserName, MultiAverage.LoadName)]
[assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName,
MultiAverage.LoadName, MultiAverage.LoaderSignature)]
@@ -33,13 +33,13 @@ private static VersionInfo GetVersionInfo()
}
[TlcModule.Component(Name = LoadName, FriendlyName = Average.UserName)]
- public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+ public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory
{
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiAverage(env, this);
}
- public MultiAverage(IHostEnvironment env, Arguments args)
- : base(env, LoaderSignature, args)
+ public MultiAverage(IHostEnvironment env, Options options)
+ : base(env, LoaderSignature, options)
{
}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
index 3b44413397..f7a66c3dd7 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
@@ -10,7 +10,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.Ensemble;
-[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner),
+[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Options), typeof(SignatureCombiner),
Median.UserName, MultiMedian.LoadName)]
[assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)]
@@ -36,13 +36,13 @@ private static VersionInfo GetVersionInfo()
}
[TlcModule.Component(Name = LoadName, FriendlyName = Median.UserName)]
- public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+ public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory
{
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiMedian(env, this);
}
- public MultiMedian(IHostEnvironment env, Arguments args)
- : base(env, LoaderSignature, args)
+ public MultiMedian(IHostEnvironment env, Options options)
+ : base(env, LoaderSignature, options)
{
}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
index f943a4176c..917071e42a 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
@@ -33,7 +33,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(MultiVoting).Assembly.FullName);
}
- private sealed class Arguments : ArgumentsBase
+ private sealed class Arguments : OptionsBase
{
}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
index aecc3963bc..96f05371fa 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
@@ -11,7 +11,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.Ensemble;
-[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner),
+[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Options), typeof(SignatureCombiner),
MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)]
[assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel),
@@ -40,7 +40,7 @@ private static VersionInfo GetVersionInfo()
}
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
- public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
+ public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory
{
IMultiClassOutputCombiner IComponentFactory.CreateComponent(IHostEnvironment env) => new MultiWeightedAverage(env, this);
@@ -52,11 +52,11 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
private readonly MultiWeightageKind _weightageKind;
public string WeightageMetricName { get { return _weightageKind.ToString(); } }
- public MultiWeightedAverage(IHostEnvironment env, Arguments args)
- : base(env, LoaderSignature, args)
+ public MultiWeightedAverage(IHostEnvironment env, Options options)
+ : base(env, LoaderSignature, options)
{
- _weightageKind = args.WeightageName;
- Host.CheckUserArg(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind), nameof(args.WeightageName));
+ _weightageKind = options.WeightageName;
+ Host.CheckUserArg(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind), nameof(options.WeightageName));
}
private MultiWeightedAverage(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
index e5aa458768..04e6f802c1 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
@@ -11,7 +11,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Trainers.Ensemble;
-[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner),
+[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Options), typeof(SignatureCombiner),
WeightedAverage.UserName, WeightedAverage.LoadName)]
[assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel),
@@ -37,7 +37,7 @@ private static VersionInfo GetVersionInfo()
}
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
- public sealed class Arguments: ISupportBinaryOutputCombinerFactory
+ public sealed class Options: ISupportBinaryOutputCombinerFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)]
[TGUI(Label = "Weightage Name", Description = "The weights are calculated according to the selected metric")]
@@ -50,11 +50,11 @@ public sealed class Arguments: ISupportBinaryOutputCombinerFactory
public string WeightageMetricName { get { return _weightageKind.ToString(); } }
- public WeightedAverage(IHostEnvironment env, Arguments args)
+ public WeightedAverage(IHostEnvironment env, Options options)
: base(env, LoaderSignature)
{
- _weightageKind = args.WeightageName;
- Host.CheckUserArg(Enum.IsDefined(typeof(WeightageKind), _weightageKind), nameof(args.WeightageName));
+ _weightageKind = options.WeightageName;
+ Host.CheckUserArg(Enum.IsDefined(typeof(WeightageKind), _weightageKind), nameof(options.WeightageName));
}
private WeightedAverage(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 4420b15ba9..0ee01dfed4 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -44,7 +44,7 @@ public sealed class Arguments : ArgumentsBase
[Argument(ArgumentType.Multiple, HelpText = "Output combiner", ShortName = "oc", SortOrder = 5)]
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
- public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments();
+ public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Options();
// REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
diff --git a/src/Microsoft.ML.EntryPoints/ImportTextData.cs b/src/Microsoft.ML.EntryPoints/ImportTextData.cs
index 09b878448c..50544ce16d 100644
--- a/src/Microsoft.ML.EntryPoints/ImportTextData.cs
+++ b/src/Microsoft.ML.EntryPoints/ImportTextData.cs
@@ -49,7 +49,7 @@ public sealed class LoaderInput
public IFileHandle InputFile;
[Argument(ArgumentType.Required, ShortName = "args", HelpText = "Arguments", SortOrder = 2)]
- public TextLoader.Arguments Arguments = new TextLoader.Arguments();
+ public TextLoader.Options Arguments = new TextLoader.Options();
}
[TlcModule.EntryPoint(Name = "Data.TextLoader", Desc = "Import a dataset from a text file")]
diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
index 13facf2757..d154195982 100644
--- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs
+++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -29,32 +29,32 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
double learningRate)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
{
- Args.LearningRates = learningRate;
+ FastTreeTrainerOptions.LearningRates = learningRate;
}
protected override void CheckArgs(IChannel ch)
{
- if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent)
- Args.UseLineSearch = true;
- if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent)
- Args.UseLineSearch = true;
+ if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent)
+ FastTreeTrainerOptions.UseLineSearch = true;
+ if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent)
+ FastTreeTrainerOptions.UseLineSearch = true;
- if (Args.CompressEnsemble && Args.WriteLastEnsemble)
+ if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble)
throw ch.Except("Ensemble compression cannot be done when forcing to write last ensemble (hl)");
- if (Args.NumLeaves > 2 && Args.HistogramPoolSize > Args.NumLeaves - 1)
+ if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
throw ch.Except("Histogram pool size (ps) must be at least 2.");
- if (Args.NumLeaves > 2 && Args.HistogramPoolSize > Args.NumLeaves - 1)
+ if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
throw ch.Except("Histogram pool size (ps) must be at most numLeaves - 1.");
- if (Args.EnablePruning && !HasValidSet)
+ if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");
- if (Args.EarlyStoppingRule != null && !HasValidSet)
+ if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet)
throw ch.Except("Cannot perform early stopping without a validation set (valid).");
- if (Args.UseTolerantPruning && (!Args.EnablePruning || !HasValidSet))
+ if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
throw ch.Except("Cannot perform tolerant pruning (prtol) without pruning (pruning) and a validation set (valid)");
base.CheckArgs(ch);
@@ -63,12 +63,12 @@ protected override void CheckArgs(IChannel ch)
private protected override TreeLearner ConstructTreeLearner(IChannel ch)
{
return new LeastSquaresRegressionTreeLearner(
- TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,
- Args.FeatureFirstUsePenalty, Args.FeatureReusePenalty, Args.SoftmaxTemperature,
- Args.HistogramPoolSize, Args.RngSeed, Args.SplitFraction, Args.FilterZeroLambdas,
- Args.AllowEmptyTrees, Args.GainConfidenceLevel, Args.MaxCategoricalGroupsPerNode,
- Args.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining,
- Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias);
+ TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient,
+ FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
+ FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction, FastTreeTrainerOptions.FilterZeroLambdas,
+ FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode,
+ FastTreeTrainerOptions.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining,
+ FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias);
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
@@ -77,7 +77,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
OptimizationAlgorithm optimizationAlgorithm;
IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch);
- switch (Args.OptimizationAlgorithm)
+ switch (FastTreeTrainerOptions.OptimizationAlgorithm)
{
case BoostedTreeArgs.OptimizationAlgorithmType.GradientDescent:
optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
@@ -89,14 +89,14 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
break;
default:
- throw ch.Except("Unknown optimization algorithm '{0}'", Args.OptimizationAlgorithm);
+ throw ch.Except("Unknown optimization algorithm '{0}'", FastTreeTrainerOptions.OptimizationAlgorithm);
}
optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
- optimizationAlgorithm.Smoothing = Args.Smoothing;
- optimizationAlgorithm.DropoutRate = Args.DropoutRate;
- optimizationAlgorithm.DropoutRng = new Random(Args.RngSeed);
+ optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
+ optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate;
+ optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.RngSeed);
optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;
return optimizationAlgorithm;
@@ -104,7 +104,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
{
- if (!Args.BestStepRankingRegressionTrees)
+ if (!FastTreeTrainerOptions.BestStepRankingRegressionTrees)
return base.MakeGradientWrapper(ch);
// REVIEW: If this is ranking specific than cmd.bestStepRankingRegressionTrees and
@@ -117,7 +117,7 @@ protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
{
- if (Args.EarlyStoppingRule == null)
+ if (FastTreeTrainerOptions.EarlyStoppingRule == null)
return false;
ch.AssertValue(ValidTest);
@@ -133,7 +133,7 @@ protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earl
// Create early stopping rule.
if (earlyStoppingRule == null)
{
- earlyStoppingRule = Args.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
+ earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
ch.Assert(earlyStoppingRule != null);
}
@@ -150,7 +150,7 @@ protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earl
protected override int GetBestIteration(IChannel ch)
{
int bestIteration = Ensemble.NumTrees;
- if (!Args.WriteLastEnsemble && PruningTest != null)
+ if (!FastTreeTrainerOptions.WriteLastEnsemble && PruningTest != null)
{
bestIteration = PruningTest.BestIteration;
ch.Info("Pruning picked iteration {0}", bestIteration);
@@ -163,15 +163,15 @@ protected override int GetBestIteration(IChannel ch)
///
protected double BsrMaxTreeOutput()
{
- if (Args.BestStepRankingRegressionTrees)
- return Args.MaxTreeOutput;
+ if (FastTreeTrainerOptions.BestStepRankingRegressionTrees)
+ return FastTreeTrainerOptions.MaxTreeOutput;
else
return -1;
}
protected override bool ShouldRandomStartOptimizer()
{
- return Args.RandomStart;
+ return FastTreeTrainerOptions.RandomStart;
}
}
}
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index a3cd3cc550..b35acb44eb 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -46,13 +46,13 @@ internal static class FastTreeShared
public static readonly object TrainLock = new object();
}
- public abstract class FastTreeTrainerBase :
+ public abstract class FastTreeTrainerBase :
TrainerEstimatorBaseWithGroupId
where TTransformer : ISingleFeaturePredictionTransformer
- where TArgs : TreeArgs, new()
+ where TOptions : TreeOptions, new()
where TModel : class
{
- protected readonly TArgs Args;
+ protected readonly TOptions FastTreeTrainerOptions;
protected readonly bool AllowGC;
protected int FeatureCount;
private protected InternalTreeEnsemble TrainedEnsemble;
@@ -95,7 +95,7 @@ public abstract class FastTreeTrainerBase :
// random for active features selection
private Random _featureSelectionRandom;
- protected string InnerArgs => CmdParser.GetSettings(Host, Args, new TArgs());
+ protected string InnerArgs => CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions());
public override TrainerInfo Info { get; }
@@ -116,22 +116,22 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
int minDatapointsInLeaves)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
- Args = new TArgs();
+ FastTreeTrainerOptions = new TOptions();
// set up the directly provided values
// override with the directly provided values.
- Args.NumLeaves = numLeaves;
- Args.NumTrees = numTrees;
- Args.MinDocumentsInLeafs = minDatapointsInLeaves;
+ FastTreeTrainerOptions.NumLeaves = numLeaves;
+ FastTreeTrainerOptions.NumTrees = numTrees;
+ FastTreeTrainerOptions.MinDocumentsInLeafs = minDatapointsInLeaves;
- Args.LabelColumn = label.Name;
- Args.FeatureColumn = featureColumn;
+ FastTreeTrainerOptions.LabelColumn = label.Name;
+ FastTreeTrainerOptions.FeatureColumn = featureColumn;
if (weightColumn != null)
- Args.WeightColumn = Optional.Explicit(weightColumn);
+ FastTreeTrainerOptions.WeightColumn = Optional.Explicit(weightColumn);
if (groupIdColumn != null)
- Args.GroupIdColumn = Optional.Explicit(groupIdColumn);
+ FastTreeTrainerOptions.GroupIdColumn = Optional.Explicit(groupIdColumn);
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
@@ -149,11 +149,11 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
///
/// Constructor that is used when invoking the classes deriving from this, through maml.
///
- private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
- : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
+ private protected FastTreeTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn))
{
- Host.CheckValue(args, nameof(args));
- Args = args;
+ Host.CheckValue(options, nameof(options));
+ FastTreeTrainerOptions = options;
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
@@ -186,7 +186,7 @@ protected virtual Float GetMaxLabel()
private void Initialize(IHostEnvironment env)
{
- int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
+ int numThreads = FastTreeTrainerOptions.NumThreads ?? Environment.ProcessorCount;
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
{
using (var ch = Host.Start("FastTreeTrainerBase"))
@@ -196,7 +196,7 @@ private void Initialize(IHostEnvironment env)
+ "setting of the environment. Using {0} training threads instead.", numThreads);
}
}
- ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
+ ParallelTraining = FastTreeTrainerOptions.ParallelTrainer != null ? FastTreeTrainerOptions.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
ParallelTraining.InitEnvironment();
Tests = new List();
@@ -207,15 +207,15 @@ private void Initialize(IHostEnvironment env)
private protected void ConvertData(RoleMappedData trainData)
{
MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Value.Index, out CategoricalFeatures);
- var useTranspose = UseTranspose(Args.DiskTranspose, trainData) && (ValidData == null || UseTranspose(Args.DiskTranspose, ValidData));
- var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocumentsInLeafs, GetMaxLabel());
+ var useTranspose = UseTranspose(FastTreeTrainerOptions.DiskTranspose, trainData) && (ValidData == null || UseTranspose(FastTreeTrainerOptions.DiskTranspose, ValidData));
+ var instanceConverter = new ExamplesToFastTreeBins(Host, FastTreeTrainerOptions.MaxBins, useTranspose, !FastTreeTrainerOptions.FeatureFlocks, FastTreeTrainerOptions.MinDocumentsInLeafs, GetMaxLabel());
- TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, CategoricalFeatures, Args.CategoricalSplit);
+ TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit);
FeatureMap = instanceConverter.FeatureMap;
if (ValidData != null)
- ValidSet = instanceConverter.GetCompatibleDataset(ValidData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit);
+ ValidSet = instanceConverter.GetCompatibleDataset(ValidData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit);
if (TestData != null)
- TestSets = new[] { instanceConverter.GetCompatibleDataset(TestData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit) };
+ TestSets = new[] { instanceConverter.GetCompatibleDataset(TestData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit) };
}
private bool UseTranspose(bool? useTranspose, RoleMappedData data)
@@ -246,7 +246,7 @@ protected void TrainCore(IChannel ch)
}
using (Timer.Time(TimerEvent.TotalTrain))
Train(ch);
- if (Args.ExecutionTimes)
+ if (FastTreeTrainerOptions.ExecutionTimes)
PrintExecutionTimes(ch);
TrainedEnsemble = Ensemble;
if (FeatureMap != null)
@@ -274,24 +274,24 @@ protected virtual void PrintExecutionTimes(IChannel ch)
protected virtual void CheckArgs(IChannel ch)
{
- Args.Check(ch);
+ FastTreeTrainerOptions.Check(ch);
- IntArray.CompatibilityLevel = Args.FeatureCompressionLevel;
+ IntArray.CompatibilityLevel = FastTreeTrainerOptions.FeatureCompressionLevel;
// change arguments
- if (Args.HistogramPoolSize < 2)
- Args.HistogramPoolSize = Args.NumLeaves * 2 / 3;
- if (Args.HistogramPoolSize > Args.NumLeaves - 1)
- Args.HistogramPoolSize = Args.NumLeaves - 1;
+ if (FastTreeTrainerOptions.HistogramPoolSize < 2)
+ FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumLeaves * 2 / 3;
+ if (FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
+ FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumLeaves - 1;
- if (Args.BaggingSize > 0)
+ if (FastTreeTrainerOptions.BaggingSize > 0)
{
- int bagCount = Args.NumTrees / Args.BaggingSize;
- if (bagCount * Args.BaggingSize != Args.NumTrees)
+ int bagCount = FastTreeTrainerOptions.NumTrees / FastTreeTrainerOptions.BaggingSize;
+ if (bagCount * FastTreeTrainerOptions.BaggingSize != FastTreeTrainerOptions.NumTrees)
throw ch.Except("Number of trees should be a multiple of number bag size");
}
- if (!(0 <= Args.GainConfidenceLevel && Args.GainConfidenceLevel < 1))
+ if (!(0 <= FastTreeTrainerOptions.GainConfidenceLevel && FastTreeTrainerOptions.GainConfidenceLevel < 1))
throw ch.Except("Gain confidence level must be in the range [0,1)");
#if OLD_DATALOAD
@@ -342,7 +342,7 @@ protected void PrintTestGraph(IChannel ch)
// we call Tests computing no matter whether we require to print test graph
ComputeTests();
- if (!Args.PrintTestGraph)
+ if (!FastTreeTrainerOptions.PrintTestGraph)
return;
if (Ensemble.NumTrees == 0)
@@ -430,15 +430,15 @@ private float GetFeaturePercentInMemory(IChannel ch)
protected bool[] GetActiveFeatures()
{
var activeFeatures = Utils.CreateArray(TrainSet.NumFeatures, true);
- if (Args.FeatureFraction < 1.0)
+ if (FastTreeTrainerOptions.FeatureFraction < 1.0)
{
if (_featureSelectionRandom == null)
- _featureSelectionRandom = new Random(Args.FeatureSelectSeed);
+ _featureSelectionRandom = new Random(FastTreeTrainerOptions.FeatureSelectSeed);
for (int i = 0; i < TrainSet.NumFeatures; ++i)
{
if (activeFeatures[i])
- activeFeatures[i] = _featureSelectionRandom.NextDouble() <= Args.FeatureFraction;
+ activeFeatures[i] = _featureSelectionRandom.NextDouble() <= FastTreeTrainerOptions.FeatureFraction;
}
}
@@ -602,8 +602,8 @@ private void GenerateActiveFeatureLists(int numberOfItems)
protected virtual BaggingProvider CreateBaggingProvider()
{
- Contracts.Assert(Args.BaggingSize > 0);
- return new BaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction);
+ Contracts.Assert(FastTreeTrainerOptions.BaggingSize > 0);
+ return new BaggingProvider(TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.BaggingTrainFraction);
}
protected virtual bool ShouldRandomStartOptimizer()
@@ -614,7 +614,7 @@ protected virtual bool ShouldRandomStartOptimizer()
protected virtual void Train(IChannel ch)
{
Contracts.AssertValue(ch);
- int numTotalTrees = Args.NumTrees;
+ int numTotalTrees = FastTreeTrainerOptions.NumTrees;
ch.Info(
"Reserved memory for tree learner: {0} bytes",
@@ -634,13 +634,13 @@ protected virtual void Train(IChannel ch)
if (Ensemble.NumTrees < numTotalTrees && ShouldRandomStartOptimizer())
{
ch.Info("Randomizing start point");
- OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, false);
+ OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.RngSeed, false);
revertRandomStart = true;
}
ch.Info("Starting to train ...");
- BaggingProvider baggingProvider = Args.BaggingSize > 0 ? CreateBaggingProvider() : null;
+ BaggingProvider baggingProvider = FastTreeTrainerOptions.BaggingSize > 0 ? CreateBaggingProvider() : null;
#if OLD_DATALOAD
#if !NO_STORE
@@ -676,7 +676,7 @@ protected virtual void Train(IChannel ch)
bool[] activeFeatures = _activeFeatureSetQueue.Dequeue();
#endif
- if (Args.BaggingSize > 0 && Ensemble.NumTrees % Args.BaggingSize == 0)
+ if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.NumTrees % FastTreeTrainerOptions.BaggingSize == 0)
{
baggingProvider.GenerateNewBag();
OptimizationAlgorithm.TreeLearner.Partitioning =
@@ -699,7 +699,7 @@ protected virtual void Train(IChannel ch)
emptyTrees++;
numTotalTrees--;
}
- else if (Args.BaggingSize > 0 && Ensemble.Trees.Count() > 0)
+ else if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.Trees.Count() > 0)
{
ch.Assert(Ensemble.Trees.Last() == tree);
Ensemble.Trees.Last()
@@ -721,7 +721,7 @@ protected virtual void Train(IChannel ch)
{
revertRandomStart = false;
ch.Info("Reverting random score assignment");
- OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, true);
+ OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.RngSeed, true);
}
#if !NO_STORE
@@ -806,7 +806,7 @@ protected virtual void PrintIterationMessage(IChannel ch, IProgressChannel pch)
protected virtual void PrintTestResults(IChannel ch)
{
- if (Args.TestFrequency != int.MaxValue && (Ensemble.NumTrees % Args.TestFrequency == 0 || Ensemble.NumTrees == Args.NumTrees))
+ if (FastTreeTrainerOptions.TestFrequency != int.MaxValue && (Ensemble.NumTrees % FastTreeTrainerOptions.TestFrequency == 0 || Ensemble.NumTrees == FastTreeTrainerOptions.NumTrees))
{
var sb = new StringBuilder();
using (var sw = new StringWriter(sb))
@@ -826,9 +826,9 @@ protected virtual void PrintPrologInfo(IChannel ch)
{
Contracts.AssertValue(ch);
ch.Trace("Host = {0}", Environment.MachineName);
- ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, Args, new TArgs()));
+ ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions()));
ch.Trace("GCSettings.IsServerGC = {0}", System.Runtime.GCSettings.IsServerGC);
- ch.Trace("{0}", Args);
+ ch.Trace("{0}", FastTreeTrainerOptions);
}
protected ScoreTracker ConstructScoreTracker(Dataset set)
@@ -856,7 +856,7 @@ protected ScoreTracker ConstructScoreTracker(Dataset set)
private double[] ComputeScoresSmart(IChannel ch, Dataset set)
{
- if (!Args.CompressEnsemble)
+ if (!FastTreeTrainerOptions.CompressEnsemble)
{
foreach (var st in OptimizationAlgorithm.TrackedScores)
if (st.Dataset == set)
diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
index 97eae76ac6..f3afaa18e0 100644
--- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -150,7 +150,7 @@ internal static class Defaults
public const double LearningRates = 0.2;
}
- public abstract class TreeArgs : LearnerInputBaseWithGroupId
+ public abstract class TreeOptions : LearnerInputBaseWithGroupId
{
///
/// Allows to choose Parallel FastTree Learning Algorithm.
@@ -442,7 +442,7 @@ internal virtual void Check(IExceptionContext ectx)
}
}
- public abstract class BoostedTreeArgs : TreeArgs
+ public abstract class BoostedTreeArgs : TreeOptions
{
// REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate.
//Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero.
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index a2428aec27..e59afb3358 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -139,7 +139,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate)
{
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
- _sigmoidParameter = 2.0 * Args.LearningRates;
+ _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRates;
}
///
@@ -151,7 +151,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options optio
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
{
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
- _sigmoidParameter = 2.0 * Args.LearningRates;
+ _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRates;
}
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -193,25 +193,25 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
return new ObjectiveImpl(
TrainSet,
_trainSetLabels,
- Args.LearningRates,
- Args.Shrinkage,
+ FastTreeTrainerOptions.LearningRates,
+ FastTreeTrainerOptions.Shrinkage,
_sigmoidParameter,
- Args.UnbalancedSets,
- Args.MaxTreeOutput,
- Args.GetDerivativesSampleRate,
- Args.BestStepRankingRegressionTrees,
- Args.RngSeed,
+ FastTreeTrainerOptions.UnbalancedSets,
+ FastTreeTrainerOptions.MaxTreeOutput,
+ FastTreeTrainerOptions.GetDerivativesSampleRate,
+ FastTreeTrainerOptions.BestStepRankingRegressionTrees,
+ FastTreeTrainerOptions.RngSeed,
ParallelTraining);
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- if (Args.UseLineSearch)
+ if (FastTreeTrainerOptions.UseLineSearch)
{
var lossCalculator = new BinaryClassificationTest(optimizationAlgorithm.TrainingScores, _trainSetLabels, _sigmoidParameter);
// REVIEW: we should makeloss indices an enum in BinaryClassificationTest
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, Args.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/, Args.NumPostBracketSteps, Args.MinStepSize);
+ optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, FastTreeTrainerOptions.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize);
}
return optimizationAlgorithm;
}
@@ -258,9 +258,9 @@ protected override void InitializeTests()
}
}
- if (Args.EnablePruning && ValidSet != null)
+ if (FastTreeTrainerOptions.EnablePruning && ValidSet != null)
{
- if (!Args.UseTolerantPruning)
+ if (!FastTreeTrainerOptions.UseTolerantPruning)
{
//use simple early stopping condition
PruningTest = new TestHistory(ValidTest, 0);
@@ -268,7 +268,7 @@ protected override void InitializeTests()
else
{
//use tollerant stopping condition
- PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
+ PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
}
}
}
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index b3d0aa12af..fa4e137376 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -132,8 +132,8 @@ private Double[] GetLabelGains()
{
try
{
- Host.AssertValue(Args.CustomGains);
- return Args.CustomGains.Split(',').Select(k => Convert.ToDouble(k.Trim())).ToArray();
+ Host.AssertValue(FastTreeTrainerOptions.CustomGains);
+ return FastTreeTrainerOptions.CustomGains.Split(',').Select(k => Convert.ToDouble(k.Trim())).ToArray();
}
catch (Exception ex)
{
@@ -145,12 +145,12 @@ private Double[] GetLabelGains()
protected override void CheckArgs(IChannel ch)
{
- if (!string.IsNullOrEmpty(Args.CustomGains))
+ if (!string.IsNullOrEmpty(FastTreeTrainerOptions.CustomGains))
{
- var stringGain = Args.CustomGains.Split(',');
+ var stringGain = FastTreeTrainerOptions.CustomGains.Split(',');
if (stringGain.Length < 5)
{
- throw ch.ExceptUserArg(nameof(Args.CustomGains),
+ throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains),
"{0} an invalid number of gain levels. We require at least 5. Make certain they're comma separated.",
stringGain.Length);
}
@@ -159,7 +159,7 @@ protected override void CheckArgs(IChannel ch)
{
if (!Double.TryParse(stringGain[i], out gain[i]))
{
- throw ch.ExceptUserArg(nameof(Args.CustomGains),
+ throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains),
"Could not parse '{0}' as a floating point number", stringGain[0]);
}
}
@@ -167,7 +167,7 @@ protected override void CheckArgs(IChannel ch)
Dataset.DatasetSkeleton.LabelGainMap = gain;
}
- ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics == 1 || Args.EarlyStoppingMetrics == 3), nameof(Args.EarlyStoppingMetrics),
+ ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 3.");
base.CheckArgs(ch);
@@ -176,33 +176,33 @@ protected override void CheckArgs(IChannel ch)
protected override void Initialize(IChannel ch)
{
base.Initialize(ch);
- if (Args.CompressEnsemble)
+ if (FastTreeTrainerOptions.CompressEnsemble)
{
_ensembleCompressor = new LassoBasedEnsembleCompressor();
- _ensembleCompressor.Initialize(Args.NumTrees, TrainSet, TrainSet.Ratings, Args.RngSeed);
+ _ensembleCompressor.Initialize(FastTreeTrainerOptions.NumTrees, TrainSet, TrainSet.Ratings, FastTreeTrainerOptions.RngSeed);
}
}
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
{
- return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, Args, ParallelTraining);
+ return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, FastTreeTrainerOptions, ParallelTraining);
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- if (Args.UseLineSearch)
+ if (FastTreeTrainerOptions.UseLineSearch)
{
- _specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, Args.SortingAlgorithm, Args.EarlyStoppingMetrics);
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, Args.NumPostBracketSteps, Args.MinStepSize);
+ _specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm, FastTreeTrainerOptions.EarlyStoppingMetrics);
+ optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize);
}
return optimizationAlgorithm;
}
protected override BaggingProvider CreateBaggingProvider()
{
- Host.Assert(Args.BaggingSize > 0);
- return new RankingBaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction);
+ Host.Assert(FastTreeTrainerOptions.BaggingSize > 0);
+ return new RankingBaggingProvider(TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.BaggingTrainFraction);
}
protected override void PrepareLabels(IChannel ch)
@@ -211,17 +211,17 @@ protected override void PrepareLabels(IChannel ch)
protected override Test ConstructTestForTrainingData()
{
- return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, Args.SortingAlgorithm);
+ return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm);
}
protected override void InitializeTests()
{
- if (Args.TestFrequency != int.MaxValue)
+ if (FastTreeTrainerOptions.TestFrequency != int.MaxValue)
{
AddFullTests();
}
- if (Args.PrintTestGraph)
+ if (FastTreeTrainerOptions.PrintTestGraph)
{
// If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity)
// We need initialize first set for graph printing
@@ -238,14 +238,14 @@ protected override void InitializeTests()
if (ValidSet != null)
ValidTest = CreateSpecialValidSetTest();
- if (Args.PrintTrainValidGraph && Args.EnablePruning && _specialTrainSetTest == null)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph && FastTreeTrainerOptions.EnablePruning && _specialTrainSetTest == null)
{
_specialTrainSetTest = CreateSpecialTrainSetTest();
}
- if (Args.EnablePruning && ValidTest != null)
+ if (FastTreeTrainerOptions.EnablePruning && ValidTest != null)
{
- if (!Args.UseTolerantPruning)
+ if (!FastTreeTrainerOptions.UseTolerantPruning)
{
//use simple eraly stopping condition
PruningTest = new TestHistory(ValidTest, 0);
@@ -253,7 +253,7 @@ protected override void InitializeTests()
else
{
//use tolerant stopping condition
- PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
+ PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
}
}
}
@@ -375,7 +375,7 @@ protected override void Train(IChannel ch)
private protected override void CustomizedTrainingIteration(InternalRegressionTree tree)
{
Contracts.AssertValueOrNull(tree);
- if (tree != null && Args.CompressEnsemble)
+ if (tree != null && FastTreeTrainerOptions.CompressEnsemble)
{
double[] trainOutputs = Ensemble.GetTreeAt(Ensemble.NumTrees - 1).GetOutputs(TrainSet);
_ensembleCompressor.SetTreeScores(Ensemble.NumTrees - 1, trainOutputs);
@@ -395,7 +395,7 @@ private Test CreateStandardTest(Dataset dataset)
return new NdcgTest(
ConstructScoreTracker(dataset),
dataset.Ratings,
- Args.SortingAlgorithm);
+ FastTreeTrainerOptions.SortingAlgorithm);
}
///
@@ -408,8 +408,8 @@ private Test CreateSpecialTrainSetTest()
OptimizationAlgorithm.TrainingScores,
OptimizationAlgorithm.ObjectiveFunction as LambdaRankObjectiveFunction,
TrainSet.Ratings,
- Args.SortingAlgorithm,
- Args.EarlyStoppingMetrics);
+ FastTreeTrainerOptions.SortingAlgorithm,
+ FastTreeTrainerOptions.EarlyStoppingMetrics);
}
///
@@ -421,8 +421,8 @@ private Test CreateSpecialValidSetTest()
return new FastNdcgTest(
ConstructScoreTracker(ValidSet),
ValidSet.Ratings,
- Args.SortingAlgorithm,
- Args.EarlyStoppingMetrics);
+ FastTreeTrainerOptions.SortingAlgorithm,
+ FastTreeTrainerOptions.EarlyStoppingMetrics);
}
///
@@ -442,12 +442,12 @@ protected override string GetTestGraphHeader()
{
StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
- if (Args.PrintTrainValidGraph)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph)
{
headerBuilder.Append("\tNDCG@20\tNDCG@40");
headerBuilder.AppendFormat(
"\nNote: Printing train NDCG@{0} as NDCG@20 and validation NDCG@{0} as NDCG@40..\n",
- Args.EarlyStoppingMetrics);
+ FastTreeTrainerOptions.EarlyStoppingMetrics);
}
return headerBuilder.ToString();
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index 52a9185b9d..e6bdaa7e54 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -107,7 +107,7 @@ protected override void CheckArgs(IChannel ch)
base.CheckArgs(ch);
- ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2), nameof(Args.EarlyStoppingMetrics),
+ ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
@@ -118,17 +118,17 @@ private static SchemaShape.Column MakeLabelColumn(string labelColumn)
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
{
- return new ObjectiveImpl(TrainSet, Args);
+ return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions);
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- if (Args.UseLineSearch)
+ if (FastTreeTrainerOptions.UseLineSearch)
{
var lossCalculator = new RegressionTest(optimizationAlgorithm.TrainingScores);
// REVIEW: We should make loss indices an enum in BinaryClassificationTest.
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, Args.NumPostBracketSteps, Args.MinStepSize);
+ optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize);
}
return optimizationAlgorithm;
@@ -229,10 +229,10 @@ protected virtual void AddFullNDCGTests()
protected override void InitializeTests()
{
// Initialize regression tests.
- if (Args.TestFrequency != int.MaxValue)
+ if (FastTreeTrainerOptions.TestFrequency != int.MaxValue)
AddFullRegressionTests();
- if (Args.PrintTestGraph)
+ if (FastTreeTrainerOptions.PrintTestGraph)
{
// If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity),
// we need initialize first set for graph printing.
@@ -244,24 +244,24 @@ protected override void InitializeTests()
}
}
- if (Args.PrintTrainValidGraph && _trainRegressionTest == null)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null)
{
Test trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet));
_trainRegressionTest = trainRegressionTest;
}
- if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
_testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0]));
// Add early stopping if appropriate.
- TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics);
+ TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics);
if (ValidSet != null)
- ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics);
+ ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics);
- if (Args.EnablePruning && ValidTest != null)
+ if (FastTreeTrainerOptions.EnablePruning && ValidTest != null)
{
- if (Args.UseTolerantPruning) // Use simple early stopping condition.
- PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
+ if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition.
+ PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
else
PruningTest = new TestHistory(ValidTest, 0);
}
@@ -311,7 +311,7 @@ protected override string GetTestGraphHeader()
{
StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
- if (Args.PrintTrainValidGraph)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph)
{
headerBuilder.Append("\tNDCG@20\tNDCG@40");
headerBuilder.Append("\nNote: Printing train L2 error as NDCG@20 and test L2 error as NDCG@40..\n");
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index 03df71233d..2e2c27deb7 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -115,27 +115,27 @@ protected override void CheckArgs(IChannel ch)
// a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
// we just leave the existing regression checks, though with a warning.
- if (Args.EarlyStoppingMetrics > 0)
+ if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
- ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2), nameof(Args.EarlyStoppingMetrics),
+ ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
{
- return new ObjectiveImpl(TrainSet, Args);
+ return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions);
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
- if (Args.UseLineSearch)
+ if (FastTreeTrainerOptions.UseLineSearch)
{
var lossCalculator = new RegressionTest(optimizationAlgorithm.TrainingScores);
// REVIEW: We should make loss indices an enum in BinaryClassificationTest.
// REVIEW: Nope, subcomponent.
- optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, Args.NumPostBracketSteps, Args.MinStepSize);
+ optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize);
}
return optimizationAlgorithm;
@@ -171,7 +171,7 @@ protected override Test ConstructTestForTrainingData()
private void Initialize()
{
- Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]");
+ Host.CheckUserArg(1 <= FastTreeTrainerOptions.Index && FastTreeTrainerOptions.Index <= 2, nameof(FastTreeTrainerOptions.Index), "Must be in the range [1, 2]");
_outputColumns = new[]
{
@@ -202,10 +202,10 @@ private void AddFullRegressionTests()
protected override void InitializeTests()
{
// Initialize regression tests.
- if (Args.TestFrequency != int.MaxValue)
+ if (FastTreeTrainerOptions.TestFrequency != int.MaxValue)
AddFullRegressionTests();
- if (Args.PrintTestGraph)
+ if (FastTreeTrainerOptions.PrintTestGraph)
{
// If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity),
// we need initialize first set for graph printing.
@@ -217,24 +217,24 @@ protected override void InitializeTests()
}
}
- if (Args.PrintTrainValidGraph && _trainRegressionTest == null)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null)
{
Test trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet));
_trainRegressionTest = trainRegressionTest;
}
- if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
_testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0]));
// Add early stopping if appropriate.
- TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics);
+ TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics);
if (ValidSet != null)
- ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics);
+ ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics);
- if (Args.EnablePruning && ValidTest != null)
+ if (FastTreeTrainerOptions.EnablePruning && ValidTest != null)
{
- if (Args.UseTolerantPruning) // Use simple early stopping condition.
- PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
+ if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition.
+ PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
else
PruningTest = new TestHistory(ValidTest, 0);
}
@@ -249,7 +249,7 @@ protected override string GetTestGraphHeader()
{
StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
- if (Args.PrintTrainValidGraph)
+ if (FastTreeTrainerOptions.PrintTrainValidGraph)
{
headerBuilder.Append("\tNDCG@20\tNDCG@40");
headerBuilder.Append("\nNote: Printing train L2 error as NDCG@20 and test L2 error as NDCG@40..\n");
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index 8dd51c03b5..a0d199fc4b 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -33,7 +33,7 @@ public sealed class BinaryClassificationGamTrainer :
BinaryPredictionTransformer>,
CalibratedModelParametersBase>
{
- public sealed class Options : ArgumentsBase
+ public sealed class Options : OptionsBase
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")]
[TGUI(Label = "Optimize for unbalanced")]
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index d740e13589..56ec5f6176 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Trainers.FastTree
{
public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamModelParameters>
{
- public partial class Options : ArgumentsBase
+ public partial class Options : OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Metric for pruning. (For regression, 1: L1, 2:L2; default L2)", ShortName = "pmetric")]
[TGUI(Description = "Metric for pruning. (For regression, 1: L1, 2:L2; default L2")]
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index d1017b4f1d..9af1ea1c6b 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -51,10 +51,10 @@ namespace Microsoft.ML.Trainers.FastTree
///
public abstract partial class GamTrainerBase : TrainerEstimatorBase
where TTransformer: ISingleFeaturePredictionTransformer
- where TArgs : GamTrainerBase.ArgumentsBase, new()
+ where TArgs : GamTrainerBase.OptionsBase, new()
where TPredictor : class
{
- public abstract class ArgumentsBase : LearnerInputBaseWithWeight
+ public abstract class OptionsBase : LearnerInputBaseWithWeight
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The entropy (regularization) coefficient between 0 and 1", ShortName = "e")]
public double EntropyCoefficient;
diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs
index 419f2b0129..425abb73e9 100644
--- a/src/Microsoft.ML.FastTree/RandomForest.cs
+++ b/src/Microsoft.ML.FastTree/RandomForest.cs
@@ -5,7 +5,7 @@
namespace Microsoft.ML.Trainers.FastTree
{
public abstract class RandomForestTrainerBase : FastTreeTrainerBase
- where TArgs : FastForestArgumentsBase, new()
+ where TArgs : FastForestOptionsBase, new()
where TModel : class
where TTransformer: ISingleFeaturePredictionTransformer
{
@@ -46,7 +46,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
- optimizationAlgorithm.Smoothing = Args.Smoothing;
+ optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
// No notion of dropout for non-boosting applications.
optimizationAlgorithm.DropoutRate = 0;
optimizationAlgorithm.DropoutRng = null;
@@ -62,12 +62,12 @@ protected override void InitializeTests()
private protected override TreeLearner ConstructTreeLearner(IChannel ch)
{
return new RandomForestLeastSquaresTreeLearner(
- TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,
- Args.FeatureFirstUsePenalty, Args.FeatureReusePenalty, Args.SoftmaxTemperature,
- Args.HistogramPoolSize, Args.RngSeed, Args.SplitFraction,
- Args.AllowEmptyTrees, Args.GainConfidenceLevel, Args.MaxCategoricalGroupsPerNode,
- Args.MaxCategoricalSplitPoints, _quantileEnabled, Args.QuantileSampleCount, ParallelTraining,
- Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias);
+ TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient,
+ FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
+ FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction,
+ FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode,
+ FastTreeTrainerOptions.MaxCategoricalSplitPoints, _quantileEnabled, FastTreeTrainerOptions.QuantileSampleCount, ParallelTraining,
+ FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias);
}
public abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase
diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
index 348eed8429..232b7d62ff 100644
--- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -32,12 +32,12 @@
namespace Microsoft.ML.Trainers.FastTree
{
- public abstract class FastForestArgumentsBase : TreeArgs
+ public abstract class FastForestOptionsBase : TreeOptions
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of labels to be sampled from each leaf to make the distribtuion", ShortName = "qsc")]
public int QuantileSampleCount = 100;
- public FastForestArgumentsBase()
+ public FastForestOptionsBase()
{
FeatureFraction = 0.7;
BaggingSize = 1;
@@ -111,7 +111,7 @@ private static IPredictorProducing Create(IHostEnvironment env, ModelLoad
public sealed partial class FastForestClassification :
RandomForestTrainerBase, FastForestClassificationModelParameters>
{
- public sealed class Options : FastForestArgumentsBase
+ public sealed class Options : FastForestOptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")]
public Double MaxTreeOutput = 100;
@@ -196,7 +196,7 @@ private protected override FastForestClassificationModelParameters TrainModelCor
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
{
- return new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, Args);
+ return new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, FastTreeTrainerOptions);
}
protected override void PrepareLabels(IChannel ch)
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index f1a348706c..ddb79ed585 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -138,7 +138,7 @@ ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantil
public sealed partial class FastForestRegression
: RandomForestTrainerBase, FastForestRegressionModelParameters>
{
- public sealed class Options : FastForestArgumentsBase
+ public sealed class Options : FastForestOptionsBase
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Shuffle the labels on every iteration. " +
"Useful probably only if using this tree as a tree leaf featurizer for multiclass.")]
@@ -204,7 +204,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr
ConvertData(trainData);
TrainCore(ch);
}
- return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount);
+ return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, FastTreeTrainerOptions.QuantileSampleCount);
}
protected override void PrepareLabels(IChannel ch)
@@ -213,7 +213,7 @@ protected override void PrepareLabels(IChannel ch)
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
{
- return ObjectiveFunctionImplBase.Create(TrainSet, Args);
+ return ObjectiveFunctionImplBase.Create(TrainSet, FastTreeTrainerOptions);
}
protected override Test ConstructTestForTrainingData()
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
index 32768bc9a9..35b318e41a 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs
@@ -106,14 +106,14 @@ internal ImageGrayscalingTransformer(IHostEnvironment env, params (string output
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
+ env.CheckValue(options.Columns, nameof(options.Columns));
- return new ImageGrayscalingTransformer(env, args.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray())
+ return new ImageGrayscalingTransformer(env, options.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray())
.MakeDataTransform(input);
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
index 1c972e18f0..6906af7c61 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
@@ -97,9 +97,9 @@ internal ImageLoadingTransformer(IHostEnvironment env, string imageFolder = null
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView data)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView data)
{
- return new ImageLoadingTransformer(env, args.ImageFolder, args.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray())
+ return new ImageLoadingTransformer(env, options.ImageFolder, options.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray())
.MakeDataTransform(data);
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
index 2b7d2d0243..158e888310 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs
@@ -199,19 +199,19 @@ private static (string outputColumnName, string inputColumnName)[] GetColumnPair
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
+ env.CheckValue(options.Columns, nameof(options.Columns));
- var columns = new ImagePixelExtractingEstimator.ColumnInfo[args.Columns.Length];
+ var columns = new ImagePixelExtractingEstimator.ColumnInfo[options.Columns.Length];
for (int i = 0; i < columns.Length; i++)
{
- var item = args.Columns[i];
- columns[i] = new ImagePixelExtractingEstimator.ColumnInfo(item, args);
+ var item = options.Columns[i];
+ columns[i] = new ImagePixelExtractingEstimator.ColumnInfo(item, options);
}
var transformer = new ImagePixelExtractingTransformer(env, columns);
@@ -549,23 +549,23 @@ public sealed class ColumnInfo
internal bool Green => (Colors & ColorBits.Green) != 0;
internal bool Blue => (Colors & ColorBits.Blue) != 0;
- internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtractingTransformer.Options args)
+ internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtractingTransformer.Options options)
{
Contracts.CheckValue(item, nameof(item));
- Contracts.CheckValue(args, nameof(args));
+ Contracts.CheckValue(options, nameof(options));
Name = item.Name;
InputColumnName = item.Source ?? item.Name;
- if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; }
- if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; }
- if (item.UseGreen ?? args.UseGreen) { Colors |= ColorBits.Green; Planes++; }
- if (item.UseBlue ?? args.UseBlue) { Colors |= ColorBits.Blue; Planes++; }
+ if (item.UseAlpha ?? options.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; }
+ if (item.UseRed ?? options.UseRed) { Colors |= ColorBits.Red; Planes++; }
+ if (item.UseGreen ?? options.UseGreen) { Colors |= ColorBits.Green; Planes++; }
+ if (item.UseBlue ?? options.UseBlue) { Colors |= ColorBits.Blue; Planes++; }
Contracts.CheckUserArg(Planes > 0, nameof(item.UseRed), "Need to use at least one color plane");
- Interleave = item.InterleaveArgb ?? args.InterleaveArgb;
+ Interleave = item.InterleaveArgb ?? options.InterleaveArgb;
- AsFloat = item.Convert ?? args.Convert;
+ AsFloat = item.Convert ?? options.Convert;
if (!AsFloat)
{
Offset = ImagePixelExtractingTransformer.Defaults.Offset;
@@ -573,8 +573,8 @@ internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtra
}
else
{
- Offset = item.Offset ?? args.Offset ?? ImagePixelExtractingTransformer.Defaults.Offset;
- Scale = item.Scale ?? args.Scale ?? ImagePixelExtractingTransformer.Defaults.Scale;
+ Offset = item.Offset ?? options.Offset ?? ImagePixelExtractingTransformer.Defaults.Offset;
+ Scale = item.Scale ?? options.Scale ?? ImagePixelExtractingTransformer.Defaults.Scale;
Contracts.CheckUserArg(FloatUtils.IsFinite(Offset), nameof(item.Offset));
Contracts.CheckUserArg(FloatUtils.IsFiniteNonZero(Scale), nameof(item.Scale));
}
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index 82fedef36b..e66e3bb814 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -51,7 +51,7 @@ internal static class Defaults
public const int ClustersCount = 5;
}
- public class Options : UnsupervisedLearnerInputBaseWithWeight
+ public sealed class Options : UnsupervisedLearnerInputBaseWithWeight
{
///
/// The number of clusters.
diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
index 0085ad31aa..8cf005f339 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
@@ -171,8 +171,8 @@ public class Options : ISupportBoosterParameterFactory
IBoosterParameter IComponentFactory.CreateComponent(IHostEnvironment env) => CreateComponent(env);
}
- internal TreeBooster(Options args)
- : base(args)
+ internal TreeBooster(Options options)
+ : base(options)
{
Contracts.CheckUserArg(Args.MinSplitGain >= 0, nameof(Args.MinSplitGain), "must be >= 0.");
Contracts.CheckUserArg(Args.MinChildWeight >= 0, nameof(Args.MinChildWeight), "must be >= 0.");
@@ -194,7 +194,7 @@ public sealed class DartBooster : BoosterParameter
internal const string FriendlyName = "Tree Dropout Tree Booster";
[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866")]
- public class Options : TreeBooster.Options
+ public sealed class Options : TreeBooster.Options
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Drop ratio for trees. Range:(0,1).")]
[TlcModule.Range(Inf = 0.0, Max = 1.0)]
@@ -217,8 +217,8 @@ public class Options : TreeBooster.Options
internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new DartBooster(this);
}
- internal DartBooster(Options args)
- : base(args)
+ internal DartBooster(Options options)
+ : base(options)
{
Contracts.CheckUserArg(Args.DropRate > 0 && Args.DropRate < 1, nameof(Args.DropRate), "must be in (0,1).");
Contracts.CheckUserArg(Args.MaxDrop > 0, nameof(Args.MaxDrop), "must be > 0.");
@@ -238,7 +238,7 @@ public sealed class GossBooster : BoosterParameter
internal const string FriendlyName = "Gradient-based One-Size Sampling";
[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")]
- public class Options : TreeBooster.Options
+ public sealed class Options : TreeBooster.Options
{
[Argument(ArgumentType.AtMostOnce,
HelpText = "Retain ratio for large gradient instances.")]
@@ -254,8 +254,8 @@ public class Options : TreeBooster.Options
internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new GossBooster(this);
}
- internal GossBooster(Options args)
- : base(args)
+ internal GossBooster(Options options)
+ : base(options)
{
Contracts.CheckUserArg(Args.TopRate > 0 && Args.TopRate < 1, nameof(Args.TopRate), "must be in (0,1).");
Contracts.CheckUserArg(Args.OtherRate >= 0 && Args.OtherRate < 1, nameof(Args.TopRate), "must be in [0,1).");
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 1c25e8ef08..6fb7cb70e3 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -164,16 +164,16 @@ protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCa
{
base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
int numLeaves = (int)Options["num_leaves"];
- int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass);
+ int minDataPerLeaf = LightGbmTrainerOptions.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass);
Options["min_data_per_leaf"] = minDataPerLeaf;
if (!hiddenMsg)
{
- if (!Args.LearningRate.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.LearningRate) + " = " + Options["learning_rate"]);
- if (!Args.NumLeaves.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.NumLeaves) + " = " + numLeaves);
- if (!Args.MinDataPerLeaf.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.MinDataPerLeaf) + " = " + minDataPerLeaf);
+ if (!LightGbmTrainerOptions.LearningRate.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + Options["learning_rate"]);
+ if (!LightGbmTrainerOptions.NumLeaves.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumLeaves) + " = " + numLeaves);
+ if (!LightGbmTrainerOptions.MinDataPerLeaf.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinDataPerLeaf) + " = " + minDataPerLeaf);
}
}
@@ -185,14 +185,14 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel
Options["num_class"] = _numClass;
bool useSoftmax = false;
- if (Args.UseSoftmax.HasValue)
- useSoftmax = Args.UseSoftmax.Value;
+ if (LightGbmTrainerOptions.UseSoftmax.HasValue)
+ useSoftmax = LightGbmTrainerOptions.UseSoftmax.Value;
else
{
if (labels.Length >= _minDataToUseSoftmax)
useSoftmax = true;
- ch.Info("Auto-tuning parameters: " + nameof(Args.UseSoftmax) + " = " + useSoftmax);
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseSoftmax) + " = " + useSoftmax);
}
if (useSoftmax)
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index febe35eb9d..67f8d42676 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -40,7 +40,7 @@ private sealed class CategoricalMetaData
public bool[] IsCategoricalFeature;
}
- private protected readonly Options Args;
+ private protected readonly Options LightGbmTrainerOptions;
///
/// Stores argumments as objects to convert them to invariant string type in the end so that
@@ -69,21 +69,21 @@ private protected LightGbmTrainerBase(IHostEnvironment env,
int numBoostRound)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
- Args = new Options();
+ LightGbmTrainerOptions = new Options();
- Args.NumLeaves = numLeaves;
- Args.MinDataPerLeaf = minDataPerLeaf;
- Args.LearningRate = learningRate;
- Args.NumBoostRound = numBoostRound;
+ LightGbmTrainerOptions.NumLeaves = numLeaves;
+ LightGbmTrainerOptions.MinDataPerLeaf = minDataPerLeaf;
+ LightGbmTrainerOptions.LearningRate = learningRate;
+ LightGbmTrainerOptions.NumBoostRound = numBoostRound;
- Args.LabelColumn = label.Name;
- Args.FeatureColumn = featureColumn;
+ LightGbmTrainerOptions.LabelColumn = label.Name;
+ LightGbmTrainerOptions.FeatureColumn = featureColumn;
if (weightColumn != null)
- Args.WeightColumn = Optional.Explicit(weightColumn);
+ LightGbmTrainerOptions.WeightColumn = Optional.Explicit(weightColumn);
if (groupIdColumn != null)
- Args.GroupIdColumn = Optional.Explicit(groupIdColumn);
+ LightGbmTrainerOptions.GroupIdColumn = Optional.Explicit(groupIdColumn);
InitParallelTraining();
}
@@ -93,7 +93,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options
{
Host.CheckValue(options, nameof(options));
- Args = options;
+ LightGbmTrainerOptions = options;
InitParallelTraining();
}
@@ -132,8 +132,8 @@ private protected override TModel TrainModelCore(TrainContext context)
private void InitParallelTraining()
{
- Options = Args.ToDictionary(Host);
- ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();
+ Options = LightGbmTrainerOptions.ToDictionary(Host);
+ ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();
if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1)
{
@@ -170,20 +170,20 @@ private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data)
protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false)
{
- double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
- int numLeaves = Args.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
- int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1);
+ double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
+ int numLeaves = LightGbmTrainerOptions.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
+ int minDataPerLeaf = LightGbmTrainerOptions.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1);
Options["learning_rate"] = learningRate;
Options["num_leaves"] = numLeaves;
Options["min_data_per_leaf"] = minDataPerLeaf;
if (!hiddenMsg)
{
- if (!Args.LearningRate.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.LearningRate) + " = " + learningRate);
- if (!Args.NumLeaves.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.NumLeaves) + " = " + numLeaves);
- if (!Args.MinDataPerLeaf.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.MinDataPerLeaf) + " = " + minDataPerLeaf);
+ if (!LightGbmTrainerOptions.LearningRate.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + learningRate);
+ if (!LightGbmTrainerOptions.NumLeaves.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumLeaves) + " = " + numLeaves);
+ if (!LightGbmTrainerOptions.MinDataPerLeaf.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinDataPerLeaf) + " = " + minDataPerLeaf);
}
}
@@ -278,9 +278,9 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t
int[] categoricalFeatures = null;
const int useCatThreshold = 50000;
// Disable cat when data is too small, reduce the overfitting.
- bool useCat = Args.UseCat ?? numRow > useCatThreshold;
- if (!Args.UseCat.HasValue)
- ch.Info("Auto-tuning parameters: " + nameof(Args.UseCat) + " = " + useCat);
+ bool useCat = LightGbmTrainerOptions.UseCat ?? numRow > useCatThreshold;
+ if (!LightGbmTrainerOptions.UseCat.HasValue)
+ ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseCat) + " = " + useCat);
if (useCat)
{
var featureCol = trainData.Schema.Schema[DefaultColumnNames.Features];
@@ -329,7 +329,7 @@ private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out Cate
}
// Push rows into dataset.
- LoadDataset(ch, factory, dtrain, numRow, Args.BatchSize, catMetaData);
+ LoadDataset(ch, factory, dtrain, numRow, LightGbmTrainerOptions.BatchSize, catMetaData);
// Some checks.
CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups);
@@ -353,7 +353,7 @@ private Dataset LoadValidationData(IChannel ch, Dataset dtrain, RoleMappedData v
Dataset dvalid = new Dataset(dtrain, numRow, labels, weights, groups);
// Push rows into dataset.
- LoadDataset(ch, factory, dvalid, numRow, Args.BatchSize, catMetaData);
+ LoadDataset(ch, factory, dvalid, numRow, LightGbmTrainerOptions.BatchSize, catMetaData);
return dvalid;
}
@@ -373,8 +373,8 @@ private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, Catego
{
ch.Info("LightGBM objective={0}", Options["objective"]);
using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain,
- dvalid: dvalid, numIteration: Args.NumBoostRound,
- verboseEval: Args.VerboseEval, earlyStoppingRound: Args.EarlyStoppingRound))
+ dvalid: dvalid, numIteration: LightGbmTrainerOptions.NumBoostRound,
+ verboseEval: LightGbmTrainerOptions.VerboseEval, earlyStoppingRound: LightGbmTrainerOptions.EarlyStoppingRound))
{
TrainedEnsemble = bst.GetModel(catMetaData.CategoricalBoudaries);
}
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index e46592eec4..7b51f19ba8 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -48,7 +48,7 @@ public sealed class RandomizedPcaTrainer : TrainerEstimatorBase : TrainerEstimatorBase
where TTransformer : ISingleFeaturePredictionTransformer
where TModel : class
- where TArgs : LbfgsTrainerBase.ArgumentsBase, new ()
+ where TArgs : LbfgsTrainerBase.OptionsBase, new ()
{
- public abstract class ArgumentsBase : LearnerInputBaseWithWeight
+ public abstract class OptionsBase : LearnerInputBaseWithWeight
{
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
[TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")]
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
index 572def3f98..430c889550 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
@@ -40,7 +40,7 @@ public sealed partial class LogisticRegression : LbfgsTrainerBase
/// If set to truetraining statistics will be generated at the end of training.
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index f003566bc0..2a37b751a6 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -44,7 +44,7 @@ public sealed class MulticlassLogisticRegression : LbfgsTrainerBase
/// Arguments class for averaged linear trainers.
///
- public abstract class AveragedLinearArguments : OnlineLinearArguments
+ public abstract class AveragedLinearOptions : OnlineLinearOptions
{
///
/// Learning rate.
@@ -25,7 +25,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)]
[TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")]
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })]
- public float LearningRate = AveragedDefaultArgs.LearningRate;
+ public float LearningRate = AveragedDefault.LearningRate;
///
/// Determine whether to decrease the or not.
@@ -37,7 +37,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)]
[TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")]
[TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })]
- public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate;
+ public bool DecreaseLearningRate = AveragedDefault.DecreaseLearningRate;
///
/// Number of examples after which weights will be reset to the current average.
@@ -65,7 +65,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
[TGUI(Label = "L2 Regularization Weight")]
[TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
- public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight;
+ public float L2RegularizerWeight = AveragedDefault.L2RegularizerWeight;
///
/// Extra weight given to more recent updates.
@@ -104,7 +104,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments
internal float AveragedTolerance = (float)1e-2;
[BestFriend]
- internal class AveragedDefaultArgs : OnlineDefaultArgs
+ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
{
public const float LearningRate = 1;
public const bool DecreaseLearningRate = false;
@@ -118,7 +118,7 @@ public abstract class AveragedLinearTrainer : OnlineLinear
where TTransformer : ISingleFeaturePredictionTransformer
where TModel : class
{
- protected readonly new AveragedLinearArguments Args;
+ protected readonly AveragedLinearOptions AveragedLinearTrainerOptions;
protected IScalarOutputLoss LossFunction;
private protected abstract class AveragedTrainStateBase : TrainStateBase
@@ -140,7 +140,7 @@ private protected abstract class AveragedTrainStateBase : TrainStateBase
protected readonly bool Averaged;
private readonly long _resetWeightsAfterXExamples;
- private readonly AveragedLinearArguments _args;
+ private readonly AveragedLinearOptions _args;
private readonly IScalarOutputLoss _loss;
private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer parent)
@@ -148,10 +148,10 @@ private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearMod
{
// Do the other initializations by setting the setters as if user had set them
// Initialize the averaged weights if needed (i.e., do what happens when Averaged is set)
- Averaged = parent.Args.Averaged;
+ Averaged = parent.AveragedLinearTrainerOptions.Averaged;
if (Averaged)
{
- if (parent.Args.AveragedTolerance > 0)
+ if (parent.AveragedLinearTrainerOptions.AveragedTolerance > 0)
VBufferUtils.Densify(ref Weights);
Weights.CopyTo(ref TotalWeights);
}
@@ -161,8 +161,8 @@ private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearMod
// to another vector with each update.
VBufferUtils.Densify(ref Weights);
}
- _resetWeightsAfterXExamples = parent.Args.ResetWeightsAfterXExamples ?? 0;
- _args = parent.Args;
+ _resetWeightsAfterXExamples = parent.AveragedLinearTrainerOptions.ResetWeightsAfterXExamples ?? 0;
+ _args = parent.AveragedLinearTrainerOptions;
_loss = parent.LossFunction;
Gain = 1;
@@ -295,20 +295,20 @@ private void IncrementAverageNonLazy()
}
}
- protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
- : base(args, env, name, label)
+ protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label)
+ : base(options, env, name, label)
{
- Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
- Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);
+ Contracts.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), UserErrorPositive);
+ Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive);
// Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
- Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
- Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
- Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);
+ Contracts.CheckUserArg(0 <= options.L2RegularizerWeight && options.L2RegularizerWeight < 0.5, nameof(options.L2RegularizerWeight), "must be in range [0, 0.5)");
+ Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative);
+ Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative);
// Verify user didn't specify parameters that conflict
- Contracts.Check(!args.DoLazyUpdates || !args.RecencyGainMulti && args.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");
+ Contracts.Check(!options.DoLazyUpdates || !options.RecencyGainMulti && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");
- Args = args;
+ AveragedLinearTrainerOptions = options;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
index 0fe7a7ea15..5ba6dfe2d7 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
@@ -56,13 +56,13 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer
/// Options for the averaged perceptron trainer.
///
- public sealed class Options : AveragedLinearArguments
+ public sealed class Options : AveragedLinearOptions
{
///
/// A custom loss.
///
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
- public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments();
+ public ISupportClassificationLossFactory LossFunction = new HingeLoss.Options();
///
/// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration.
@@ -132,10 +132,10 @@ internal AveragedPerceptronTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
IClassificationLoss lossFunction = null,
- float learningRate = Options.AveragedDefaultArgs.LearningRate,
- bool decreaseLearningRate = Options.AveragedDefaultArgs.DecreaseLearningRate,
- float l2RegularizerWeight = Options.AveragedDefaultArgs.L2RegularizerWeight,
- int numIterations = Options.AveragedDefaultArgs.NumIterations)
+ float learningRate = Options.AveragedDefault.LearningRate,
+ bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate,
+ float l2RegularizerWeight = Options.AveragedDefault.L2RegularizerWeight,
+ int numIterations = Options.AveragedDefault.NumIterations)
: this(env, new Options
{
LabelColumn = labelColumn,
@@ -220,4 +220,4 @@ internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnviro
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
index e7b04996ac..14cc1ca85d 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
@@ -40,7 +40,7 @@ public sealed class LinearSvmTrainer : OnlineLinearTrainer LossFunctionFactory => LossFunction;
[BestFriend]
- internal class OgdDefaultArgs : AveragedDefaultArgs
+ internal class OgdDefaultArgs : AveragedDefault
{
public new const float LearningRate = 0.1f;
public new const bool DecreaseLearningRate = true;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index d8ff4b58be..76d0bd6ea0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -19,7 +19,7 @@ namespace Microsoft.ML.Trainers.Online
///
/// Arguments class for online linear trainers.
///
- public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
+ public abstract class OnlineLinearOptions : LearnerInputBaseWithLabel
{
///
/// Number of passes through the training dataset.
@@ -27,7 +27,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter, numIterations", SortOrder = 50)]
[TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
[TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
- public int NumberOfIterations = OnlineDefaultArgs.NumIterations;
+ public int NumberOfIterations = OnlineDefault.NumIterations;
///
/// Initial weights and bias, comma-separated.
@@ -60,7 +60,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
public bool Shuffle = true;
[BestFriend]
- internal class OnlineDefaultArgs
+ internal class OnlineDefault
{
public const int NumIterations = 1;
}
@@ -70,7 +70,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat
where TTransformer : ISingleFeaturePredictionTransformer
where TModel : class
{
- protected readonly OnlineLinearArguments Args;
+ protected readonly OnlineLinearOptions OnlineLinearTrainerOptions;
protected readonly string Name;
///
@@ -136,10 +136,10 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre
VBufferUtils.Densify(ref Weights);
Bias = predictor.Bias;
}
- else if (!string.IsNullOrWhiteSpace(parent.Args.InitialWeights))
+ else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights))
{
- ch.Info("Initializing weights and bias to " + parent.Args.InitialWeights);
- string[] weightStr = parent.Args.InitialWeights.Split(',');
+ ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights);
+ string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(',');
if (weightStr.Length != numFeatures + 1)
{
throw ch.Except(
@@ -153,13 +153,13 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre
Weights = new VBuffer(numFeatures, weightValues);
Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
}
- else if (parent.Args.InitialWeightsDiameter > 0)
+ else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0)
{
var weightValues = new float[numFeatures];
for (int i = 0; i < numFeatures; i++)
- weightValues[i] = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
+ weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
Weights = new VBuffer(numFeatures, weightValues);
- Bias = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
+ Bias = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
}
else if (numFeatures <= 1000)
Weights = VBufferUtils.CreateDense(numFeatures);
@@ -253,14 +253,14 @@ public virtual float Margin(in VBuffer feat)
protected virtual bool NeedCalibration => false;
- private protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
- : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights))
+ private protected OnlineLinearTrainer(OnlineLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.InitialWeights))
{
- Contracts.CheckValue(args, nameof(args));
- Contracts.CheckUserArg(args.NumberOfIterations > 0, nameof(args.NumberOfIterations), UserErrorPositive);
- Contracts.CheckUserArg(args.InitialWeightsDiameter >= 0, nameof(args.InitialWeightsDiameter), UserErrorNonNegative);
+ Contracts.CheckValue(options, nameof(options));
+ Contracts.CheckUserArg(options.NumberOfIterations > 0, nameof(options.NumberOfIterations), UserErrorPositive);
+ Contracts.CheckUserArg(options.InitialWeightsDiameter >= 0, nameof(options.InitialWeightsDiameter), UserErrorNonNegative);
- Args = args;
+ OnlineLinearTrainerOptions = options;
Name = name;
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
@@ -309,7 +309,7 @@ public TTransformer Fit(IDataView trainData, IPredictor initialPredictor)
private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
{
- bool shuffle = Args.Shuffle;
+ bool shuffle = OnlineLinearTrainerOptions.Shuffle;
if (shuffle && !data.Data.CanShuffle)
{
ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
@@ -323,7 +323,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt);
long numBad = 0;
- while (state.Iteration < Args.NumberOfIterations)
+ while (state.Iteration < OnlineLinearTrainerOptions.NumberOfIterations)
{
state.BeginIteration(ch);
@@ -341,10 +341,10 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
{
ch.Warning(
"Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)",
- numBad, Args.NumberOfIterations, numBad / Args.NumberOfIterations);
+ numBad, OnlineLinearTrainerOptions.NumberOfIterations, numBad / OnlineLinearTrainerOptions.NumberOfIterations);
}
}
private protected abstract TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor);
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
index 4ee904c76e..d09a1d9633 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
@@ -33,7 +33,7 @@ public sealed class PoissonRegression : LbfgsTrainerBase : StochasticTrainerBase
where TTransformer : ISingleFeaturePredictionTransformer
where TModel : class
- where TArgs : SdcaTrainerBase.ArgumentsBase, new()
+ where TArgs : SdcaTrainerBase.OptionsBase, new()
{
// REVIEW: Making it even faster and more accurate:
// 1. Train with not-too-many threads. nt = 2 or 4 seems to be good enough. Didn't seem additional benefit over more threads.
@@ -159,7 +159,7 @@ public abstract class SdcaTrainerBase : StochasticT
// 3. Don't "guess" the iteration to converge. It is very data-set dependent and hard to control. Always check for at least once to ensure convergence.
// 4. Use dual variable updates to infer whether a full iteration of convergence checking is necessary. Convergence checking iteration is time-consuming.
- public abstract class ArgumentsBase : LearnerInputBaseWithLabel
+ public abstract class OptionsBase : LearnerInputBaseWithLabel
{
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant. By default the l2 constant is automatically inferred based on data set.", NullName = "", ShortName = "l2", SortOrder = 1)]
[TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = ",1e-7,1e-6,1e-5,1e-4,1e-3,1e-2")]
@@ -293,7 +293,7 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d
if (Args.NumThreads.HasValue)
{
numThreads = Args.NumThreads.Value;
- Host.CheckUserArg(numThreads > 0, nameof(ArgumentsBase.NumThreads), "The number of threads must be either null or a positive integer.");
+ Host.CheckUserArg(numThreads > 0, nameof(OptionsBase.NumThreads), "The number of threads must be either null or a positive integer.");
if (0 < Host.ConcurrencyFactor && Host.ConcurrencyFactor < numThreads)
{
numThreads = Host.ConcurrencyFactor;
@@ -1417,7 +1417,7 @@ public abstract class SdcaBinaryTrainerBase :
public override TrainerInfo Info { get; }
- public class BinaryArgumentBase : ArgumentsBase
+ public class BinaryArgumentBase : OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
public float PositiveInstanceWeight = 1;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index 1e61c63490..c0f7ea4446 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -35,7 +35,7 @@ public class SdcaMultiClassTrainer : SdcaTrainerBase score, Scalar predictedLabel) AveragedPercept
Vector features,
Scalar weights = null,
IClassificationLoss lossFunction = null,
- float learningRate = AveragedLinearArguments.AveragedDefaultArgs.LearningRate,
- bool decreaseLearningRate = AveragedLinearArguments.AveragedDefaultArgs.DecreaseLearningRate,
- float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight,
- int numIterations = AveragedLinearArguments.AveragedDefaultArgs.NumIterations,
+ float learningRate = AveragedLinearOptions.AveragedDefault.LearningRate,
+ bool decreaseLearningRate = AveragedLinearOptions.AveragedDefault.DecreaseLearningRate,
+ float l2RegularizerWeight = AveragedLinearOptions.AveragedDefault.L2RegularizerWeight,
+ int numIterations = AveragedLinearOptions.AveragedDefault.NumIterations,
Action onFit = null
)
{
@@ -168,7 +168,7 @@ public static Scalar OnlineGradientDescent(this RegressionCatalog.Regress
float learningRate = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.LearningRate,
bool decreaseLearningRate = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.DecreaseLearningRate,
float l2RegularizerWeight = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.L2RegularizerWeight,
- int numIterations = OnlineLinearArguments.OnlineDefaultArgs.NumIterations,
+ int numIterations = OnlineLinearOptions.OnlineDefault.NumIterations,
Action onFit = null)
{
OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit);
diff --git a/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs b/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs
index 5d6aeb6c20..4a114cd88b 100644
--- a/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs
+++ b/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs
@@ -45,7 +45,7 @@ public static DataReader CreateReader<[IsShape] TSha
env.CheckValueOrNull(files);
// Populate all args except the columns.
- var args = new TextLoader.Arguments();
+ var args = new TextLoader.Options();
args.AllowQuoting = allowQuoting;
args.AllowSparse = allowSparse;
args.HasHeader = hasHeader;
@@ -65,15 +65,15 @@ public static DataReader CreateReader<[IsShape] TSha
private sealed class TextReconciler : ReaderReconciler
{
- private readonly TextLoader.Arguments _args;
+ private readonly TextLoader.Options _args;
private readonly IMultiStreamSource _files;
- public TextReconciler(TextLoader.Arguments args, IMultiStreamSource files)
+ public TextReconciler(TextLoader.Options options, IMultiStreamSource files)
{
- Contracts.AssertValue(args);
+ Contracts.AssertValue(options);
Contracts.AssertValueOrNull(files);
- _args = args;
+ _args = options;
_files = files;
}
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs
index 1466b78ae4..281fdbc0ce 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs
@@ -9,9 +9,9 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Sweeper;
-[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureSweeper),
+[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureSweeper),
"Random Grid Sweeper", "RandomGridSweeper", "RandomGrid")]
-[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureSweeperFromParameterList),
+[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureSweeperFromParameterList),
"Random Grid Sweeper", "RandomGridSweeperParamList", "RandomGridpl")]
namespace Microsoft.ML.Sweeper
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Sweeper
///
public abstract class SweeperBase : ISweeper
{
- public class ArgumentsBase
+ public class OptionsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))]
public IComponentFactory[] SweptParameters;
@@ -35,11 +35,11 @@ public class ArgumentsBase
public int Retries = 10;
}
- private readonly ArgumentsBase _args;
+ private readonly OptionsBase _args;
protected readonly IValueGenerator[] SweepParameters;
protected readonly IHost Host;
- protected SweeperBase(ArgumentsBase args, IHostEnvironment env, string name)
+ protected SweeperBase(OptionsBase args, IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
@@ -52,7 +52,7 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, string name)
SweepParameters = args.SweptParameters.Select(p => p.CreateComponent(Host)).ToArray();
}
- protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[] sweepParameters, string name)
+ protected SweeperBase(OptionsBase args, IHostEnvironment env, IValueGenerator[] sweepParameters, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
@@ -110,20 +110,20 @@ public sealed class RandomGridSweeper : SweeperBase
// This is a parallel array to the _permutation array and stores the (already generated) parameter sets
private readonly ParameterSet[] _cache;
- public sealed class Arguments : ArgumentsBase
+ public sealed class Options : OptionsBase
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Limit for the number of combinations to generate the entire grid.", ShortName = "maxpoints")]
public int MaxGridPoints = 1000000;
}
- public RandomGridSweeper(IHostEnvironment env, Arguments args)
- : base(args, env, "RandomGrid")
+ public RandomGridSweeper(IHostEnvironment env, Options options)
+ : base(options, env, "RandomGrid")
{
_nGridPoints = 1;
foreach (var sweptParameter in SweepParameters)
{
_nGridPoints *= sweptParameter.Count;
- if (_nGridPoints > args.MaxGridPoints)
+ if (_nGridPoints > options.MaxGridPoints)
_nGridPoints = 0;
}
if (_nGridPoints != 0)
@@ -133,14 +133,14 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args)
}
}
- public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[] sweepParameters)
- : base(args, env, sweepParameters, "RandomGrid")
+ public RandomGridSweeper(IHostEnvironment env, Options options, IValueGenerator[] sweepParameters)
+ : base(options, env, sweepParameters, "RandomGrid")
{
_nGridPoints = 1;
foreach (var sweptParameter in SweepParameters)
{
_nGridPoints *= sweptParameter.Count;
- if (_nGridPoints > args.MaxGridPoints)
+ if (_nGridPoints > options.MaxGridPoints)
_nGridPoints = 0;
}
if (_nGridPoints != 0)
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
index 469988ec7c..cf413f5aad 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs
@@ -12,7 +12,7 @@
using Microsoft.ML.Trainers.FastTree;
using Float = System.Single;
-[assembly: LoadableClass(typeof(KdoSweeper), typeof(KdoSweeper.Arguments), typeof(SignatureSweeper),
+[assembly: LoadableClass(typeof(KdoSweeper), typeof(KdoSweeper.Options), typeof(SignatureSweeper),
"KDO Sweeper", "KDOSweeper", "KDO")]
namespace Microsoft.ML.Sweeper.Algorithms
@@ -36,7 +36,7 @@ namespace Microsoft.ML.Sweeper.Algorithms
public sealed class KdoSweeper : ISweeper
{
- public sealed class Arguments
+ public sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))]
public IComponentFactory[] SweptParameters;
@@ -77,7 +77,7 @@ public sealed class Arguments
private readonly ISweeper _randomSweeper;
private readonly ISweeper _redundantSweeper;
- private readonly Arguments _args;
+ private readonly Options _args;
private readonly IHost _host;
private readonly IValueGenerator[] _sweepParameters;
@@ -85,22 +85,22 @@ public sealed class Arguments
private readonly SortedSet _alreadySeenConfigs;
private readonly List _randomParamSets;
- public KdoSweeper(IHostEnvironment env, Arguments args)
+ public KdoSweeper(IHostEnvironment env, Options options)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register("Sweeper");
- _host.CheckUserArg(args.NumberInitialPopulation > 1, nameof(args.NumberInitialPopulation), "Must be greater than 1");
- _host.CheckUserArg(args.HistoryLength > 1, nameof(args.HistoryLength), "Must be greater than 1");
- _host.CheckUserArg(args.MinimumMutationSpread >= 0, nameof(args.MinimumMutationSpread), "Must be nonnegative");
- _host.CheckUserArg(0 <= args.ProportionRandom && args.ProportionRandom <= 1, nameof(args.ProportionRandom), "Must be in [0, 1]");
- _host.CheckUserArg(args.WeightRescalingPower >= 1, nameof(args.WeightRescalingPower), "Must be greater or equal to 1");
-
- _args = args;
- _host.CheckUserArg(Utils.Size(args.SweptParameters) > 0, nameof(args.SweptParameters), "KDO sweeper needs at least one parameter to sweep over");
- _sweepParameters = args.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray();
- _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), _sweepParameters);
- _redundantSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase { Retries = 0 }, _sweepParameters);
+ _host.CheckUserArg(options.NumberInitialPopulation > 1, nameof(options.NumberInitialPopulation), "Must be greater than 1");
+ _host.CheckUserArg(options.HistoryLength > 1, nameof(options.HistoryLength), "Must be greater than 1");
+ _host.CheckUserArg(options.MinimumMutationSpread >= 0, nameof(options.MinimumMutationSpread), "Must be nonnegative");
+ _host.CheckUserArg(0 <= options.ProportionRandom && options.ProportionRandom <= 1, nameof(options.ProportionRandom), "Must be in [0, 1]");
+ _host.CheckUserArg(options.WeightRescalingPower >= 1, nameof(options.WeightRescalingPower), "Must be greater or equal to 1");
+
+ _args = options;
+ _host.CheckUserArg(Utils.Size(options.SweptParameters) > 0, nameof(options.SweptParameters), "KDO sweeper needs at least one parameter to sweep over");
+ _sweepParameters = options.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray();
+ _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase(), _sweepParameters);
+ _redundantSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase { Retries = 0 }, _sweepParameters);
_spu = new SweeperProbabilityUtils(_host);
_alreadySeenConfigs = new SortedSet(new FloatArrayComparer());
_randomParamSets = new List();
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs b/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs
index 51dd6c91f0..14c58ca49a 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs
@@ -11,20 +11,20 @@
using Microsoft.ML.Sweeper;
using Float = System.Single;
-[assembly: LoadableClass(typeof(NelderMeadSweeper), typeof(NelderMeadSweeper.Arguments), typeof(SignatureSweeper),
+[assembly: LoadableClass(typeof(NelderMeadSweeper), typeof(NelderMeadSweeper.Options), typeof(SignatureSweeper),
"Nelder Mead Sweeper", "NelderMeadSweeper", "NelderMead", "NM")]
namespace Microsoft.ML.Sweeper
{
public sealed class NelderMeadSweeper : ISweeper
{
- public sealed class Arguments
+ public sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))]
public IComponentFactory[] SweptParameters;
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The sweeper used to get the initial results.", ShortName = "init", SignatureType = typeof(SignatureSweeperFromParameterList))]
- public IComponentFactory FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction((host, array) => new UniformRandomSweeper(host, new SweeperBase.ArgumentsBase(), array));
+ public IComponentFactory FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction((host, array) => new UniformRandomSweeper(host, new SweeperBase.OptionsBase(), array));
[Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random number generator for the first batch sweeper", ShortName = "seed")]
public int RandomSeed;
@@ -66,7 +66,7 @@ private enum OptimizationStage
}
private readonly ISweeper _initSweeper;
- private readonly Arguments _args;
+ private readonly Options _args;
private SortedList _simplexVertices;
private readonly int _dim;
@@ -84,21 +84,21 @@ private enum OptimizationStage
private readonly List _sweepParameters;
- public NelderMeadSweeper(IHostEnvironment env, Arguments args)
+ public NelderMeadSweeper(IHostEnvironment env, Options options)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckUserArg(-1 < args.DeltaInsideContraction, nameof(args.DeltaInsideContraction), "Must be greater than -1");
- env.CheckUserArg(args.DeltaInsideContraction < 0, nameof(args.DeltaInsideContraction), "Must be less than 0");
- env.CheckUserArg(0 < args.DeltaOutsideContraction, nameof(args.DeltaOutsideContraction), "Must be greater than 0");
- env.CheckUserArg(args.DeltaReflection > args.DeltaOutsideContraction, nameof(args.DeltaReflection), "Must be greater than " + nameof(args.DeltaOutsideContraction));
- env.CheckUserArg(args.DeltaExpansion > args.DeltaReflection, nameof(args.DeltaExpansion), "Must be greater than " + nameof(args.DeltaReflection));
- env.CheckUserArg(0 < args.GammaShrink && args.GammaShrink < 1, nameof(args.GammaShrink), "Must be between 0 and 1");
- env.CheckValue(args.FirstBatchSweeper, nameof(args.FirstBatchSweeper) , "First Batch Sweeper Contains Null Value");
+ env.CheckUserArg(-1 < options.DeltaInsideContraction, nameof(options.DeltaInsideContraction), "Must be greater than -1");
+ env.CheckUserArg(options.DeltaInsideContraction < 0, nameof(options.DeltaInsideContraction), "Must be less than 0");
+ env.CheckUserArg(0 < options.DeltaOutsideContraction, nameof(options.DeltaOutsideContraction), "Must be greater than 0");
+ env.CheckUserArg(options.DeltaReflection > options.DeltaOutsideContraction, nameof(options.DeltaReflection), "Must be greater than " + nameof(options.DeltaOutsideContraction));
+ env.CheckUserArg(options.DeltaExpansion > options.DeltaReflection, nameof(options.DeltaExpansion), "Must be greater than " + nameof(options.DeltaReflection));
+ env.CheckUserArg(0 < options.GammaShrink && options.GammaShrink < 1, nameof(options.GammaShrink), "Must be between 0 and 1");
+ env.CheckValue(options.FirstBatchSweeper, nameof(options.FirstBatchSweeper) , "First Batch Sweeper Contains Null Value");
- _args = args;
+ _args = options;
_sweepParameters = new List();
- foreach (var sweptParameter in args.SweptParameters)
+ foreach (var sweptParameter in options.SweptParameters)
{
var parameter = sweptParameter.CreateComponent(env);
// REVIEW: ideas about how to support discrete values:
@@ -108,13 +108,13 @@ public NelderMeadSweeper(IHostEnvironment env, Arguments args)
// the metric values that we get when using them. (For example, if, for a given discrete value, we get a bad result,
// we lower its weight, but if we get a good result we increase its weight).
var parameterNumeric = parameter as INumericValueGenerator;
- env.CheckUserArg(parameterNumeric != null, nameof(args.SweptParameters), "Nelder-Mead sweeper can only sweep over numeric parameters");
+ env.CheckUserArg(parameterNumeric != null, nameof(options.SweptParameters), "Nelder-Mead sweeper can only sweep over numeric parameters");
_sweepParameters.Add(parameterNumeric);
}
- _initSweeper = args.FirstBatchSweeper.CreateComponent(env, _sweepParameters.ToArray());
+ _initSweeper = options.FirstBatchSweeper.CreateComponent(env, _sweepParameters.ToArray());
_dim = _sweepParameters.Count;
- env.CheckUserArg(_dim > 1, nameof(args.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over.");
+ env.CheckUserArg(_dim > 1, nameof(options.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over.");
_simplexVertices = new SortedList(new SimplexVertexComparer());
_stage = OptimizationStage.NeedReflectionPoint;
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/Random.cs b/src/Microsoft.ML.Sweeper/Algorithms/Random.cs
index 6711696bf4..105c6b648d 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/Random.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/Random.cs
@@ -6,9 +6,9 @@
using Microsoft.ML;
using Microsoft.ML.Sweeper;
-[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureSweeper),
+[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureSweeper),
"Uniform Random Sweeper", "UniformRandomSweeper", "UniformRandom")]
-[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureSweeperFromParameterList),
+[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureSweeperFromParameterList),
"Uniform Random Sweeper", "UniformRandomSweeperParamList", "UniformRandompl")]
namespace Microsoft.ML.Sweeper
@@ -18,12 +18,12 @@ namespace Microsoft.ML.Sweeper
///
public sealed class UniformRandomSweeper : SweeperBase
{
- public UniformRandomSweeper(IHostEnvironment env, ArgumentsBase args)
+ public UniformRandomSweeper(IHostEnvironment env, OptionsBase args)
: base(args, env, "UniformRandom")
{
}
- public UniformRandomSweeper(IHostEnvironment env, ArgumentsBase args, IValueGenerator[] sweepParameters)
+ public UniformRandomSweeper(IHostEnvironment env, OptionsBase args, IValueGenerator[] sweepParameters)
: base(args, env, sweepParameters, "UniformRandom")
{
}
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
index b7b1c9563e..6053bdc4bd 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
@@ -15,7 +15,7 @@
using Microsoft.ML.Trainers.FastTree;
using Float = System.Single;
-[assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Arguments), typeof(SignatureSweeper),
+[assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Options), typeof(SignatureSweeper),
"SMAC Sweeper", "SMACSweeper", "SMAC")]
namespace Microsoft.ML.Sweeper
@@ -24,7 +24,7 @@ namespace Microsoft.ML.Sweeper
//encapsulating common functionality. This seems like a good plan to persue.
public sealed class SmacSweeper : ISweeper
{
- public sealed class Arguments
+ public sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))]
public IComponentFactory[] SweptParameters;
@@ -61,28 +61,28 @@ public sealed class Arguments
}
private readonly ISweeper _randomSweeper;
- private readonly Arguments _args;
+ private readonly Options _args;
private readonly IHost _host;
private readonly IValueGenerator[] _sweepParameters;
- public SmacSweeper(IHostEnvironment env, Arguments args)
+ public SmacSweeper(IHostEnvironment env, Options options)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register("Sweeper");
- _host.CheckUserArg(args.NumOfTrees > 0, nameof(args.NumOfTrees), "parameter must be greater than 0");
- _host.CheckUserArg(args.NMinForSplit > 1, nameof(args.NMinForSplit), "parameter must be greater than 1");
- _host.CheckUserArg(args.SplitRatio > 0 && args.SplitRatio <= 1, nameof(args.SplitRatio), "parameter must be in range (0,1].");
- _host.CheckUserArg(args.NumberInitialPopulation > 1, nameof(args.NumberInitialPopulation), "parameter must be greater than 1");
- _host.CheckUserArg(args.LocalSearchParentCount > 0, nameof(args.LocalSearchParentCount), "parameter must be greater than 0");
- _host.CheckUserArg(args.NumRandomEISearchConfigurations > 0, nameof(args.NumRandomEISearchConfigurations), "parameter must be greater than 0");
- _host.CheckUserArg(args.NumNeighborsForNumericalParams > 0, nameof(args.NumNeighborsForNumericalParams), "parameter must be greater than 0");
-
- _args = args;
- _host.CheckUserArg(Utils.Size(args.SweptParameters) > 0, nameof(args.SweptParameters), "SMAC sweeper needs at least one parameter to sweep over");
- _sweepParameters = args.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray();
- _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), _sweepParameters);
+ _host.CheckUserArg(options.NumOfTrees > 0, nameof(options.NumOfTrees), "parameter must be greater than 0");
+ _host.CheckUserArg(options.NMinForSplit > 1, nameof(options.NMinForSplit), "parameter must be greater than 1");
+ _host.CheckUserArg(options.SplitRatio > 0 && options.SplitRatio <= 1, nameof(options.SplitRatio), "parameter must be in range (0,1].");
+ _host.CheckUserArg(options.NumberInitialPopulation > 1, nameof(options.NumberInitialPopulation), "parameter must be greater than 1");
+ _host.CheckUserArg(options.LocalSearchParentCount > 0, nameof(options.LocalSearchParentCount), "parameter must be greater than 0");
+ _host.CheckUserArg(options.NumRandomEISearchConfigurations > 0, nameof(options.NumRandomEISearchConfigurations), "parameter must be greater than 0");
+ _host.CheckUserArg(options.NumNeighborsForNumericalParams > 0, nameof(options.NumNeighborsForNumericalParams), "parameter must be greater than 0");
+
+ _args = options;
+ _host.CheckUserArg(Utils.Size(options.SweptParameters) > 0, nameof(options.SweptParameters), "SMAC sweeper needs at least one parameter to sweep over");
+ _sweepParameters = options.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray();
+ _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase(), _sweepParameters);
}
public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previousRuns = null)
diff --git a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
index 4713bb54dd..579f20c53b 100644
--- a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs
@@ -13,11 +13,11 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Sweeper;
-[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureAsyncSweeper),
+[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureAsyncSweeper),
"Asynchronous Uniform Random Sweeper", "UniformRandomSweeper", "UniformRandom")]
-[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureAsyncSweeper),
+[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureAsyncSweeper),
"Asynchronous Random Grid Sweeper", "RandomGridSweeper", "RandomGrid")]
-[assembly: LoadableClass(typeof(DeterministicSweeperAsync), typeof(DeterministicSweeperAsync.Arguments), typeof(SignatureAsyncSweeper),
+[assembly: LoadableClass(typeof(DeterministicSweeperAsync), typeof(DeterministicSweeperAsync.Options), typeof(SignatureAsyncSweeper),
"Asynchronous and Deterministic Sweeper", "DeterministicSweeper", "Deterministic")]
namespace Microsoft.ML.Sweeper
@@ -88,13 +88,13 @@ private SimpleAsyncSweeper(ISweeper baseSweeper)
_results = new List();
}
- public SimpleAsyncSweeper(IHostEnvironment env, UniformRandomSweeper.ArgumentsBase args)
- : this(new UniformRandomSweeper(env, args))
+ public SimpleAsyncSweeper(IHostEnvironment env, UniformRandomSweeper.OptionsBase options)
+ : this(new UniformRandomSweeper(env, options))
{
}
- public SimpleAsyncSweeper(IHostEnvironment env, RandomGridSweeper.Arguments args)
- : this(new UniformRandomSweeper(env, args))
+ public SimpleAsyncSweeper(IHostEnvironment env, RandomGridSweeper.Options options)
+ : this(new UniformRandomSweeper(env, options))
{
}
@@ -150,7 +150,7 @@ public void Dispose()
///
public sealed class DeterministicSweeperAsync : IAsyncSweeper, IDisposable
{
- public sealed class Arguments
+ public sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Base sweeper", ShortName = "sweeper", SignatureType = typeof(SignatureSweeper))]
public IComponentFactory Sweeper;
@@ -190,21 +190,21 @@ public sealed class Arguments
// The number of ParameterSets generated so far. Used for indexing.
private int _numGenerated;
- public DeterministicSweeperAsync(IHostEnvironment env, Arguments args)
+ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
{
- _host = env.Register("DeterministicSweeperAsync", args.RandomSeed);
- _host.CheckValue(args.Sweeper, nameof(args.Sweeper), "Please specify a sweeper");
- _host.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), "Batch size must be positive");
- _host.CheckUserArg(args.Relaxation >= 0, nameof(args.Relaxation), "Synchronization relaxation must be non-negative");
- _host.CheckUserArg(args.Relaxation <= args.BatchSize, nameof(args.Relaxation),
+ _host = env.Register("DeterministicSweeperAsync", options.RandomSeed);
+ _host.CheckValue(options.Sweeper, nameof(options.Sweeper), "Please specify a sweeper");
+ _host.CheckUserArg(options.BatchSize > 0, nameof(options.BatchSize), "Batch size must be positive");
+ _host.CheckUserArg(options.Relaxation >= 0, nameof(options.Relaxation), "Synchronization relaxation must be non-negative");
+ _host.CheckUserArg(options.Relaxation <= options.BatchSize, nameof(options.Relaxation),
"Synchronization relaxation cannot be larger than batch size");
- _batchSize = args.BatchSize;
- _baseSweeper = args.Sweeper.CreateComponent(_host);
- _host.CheckUserArg(!(_baseSweeper is NelderMeadSweeper) || args.Relaxation == 0, nameof(args.Relaxation),
+ _batchSize = options.BatchSize;
+ _baseSweeper = options.Sweeper.CreateComponent(_host);
+ _host.CheckUserArg(!(_baseSweeper is NelderMeadSweeper) || options.Relaxation == 0, nameof(options.Relaxation),
"Nelder-Mead requires full synchronization (relaxation = 0)");
_cts = new CancellationTokenSource();
- _relaxation = args.Relaxation;
+ _relaxation = options.Relaxation;
_lock = new object();
_results = new List();
_nullRuns = new HashSet();
diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
index cbe2a99900..5f2e5269f7 100644
--- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs
+++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
@@ -14,7 +14,7 @@
using ResultProcessorInternal = Microsoft.ML.ResultProcessor;
-[assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Arguments), typeof(SignatureConfigRunner),
+[assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Options), typeof(SignatureConfigRunner),
"Local Sweep Config Runner", "Local")]
namespace Microsoft.ML.Sweeper
@@ -30,7 +30,7 @@ public interface IConfigRunner
public abstract class ExeConfigRunnerBase : IConfigRunner
{
- public abstract class ArgumentsBase
+ public abstract class OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Command pattern for the sweeps", ShortName = "pattern")]
public string ArgsPattern;
@@ -46,7 +46,7 @@ public abstract class ArgumentsBase
[Argument(ArgumentType.Multiple, HelpText = "Specify how to extract the metrics from the result file.", ShortName = "ev", SignatureType = typeof(SignatureSweepResultEvaluator))]
public IComponentFactory> ResultProcessor = ComponentFactoryUtils.CreateFromFunction(
- env => new InternalSweepResultEvaluator(env, new InternalSweepResultEvaluator.Arguments()));
+ env => new InternalSweepResultEvaluator(env, new InternalSweepResultEvaluator.Options()));
[Argument(ArgumentType.AtMostOnce, Hide = true)]
public bool CalledFromUnitTestSuite;
@@ -63,7 +63,7 @@ public abstract class ArgumentsBase
private readonly bool _calledFromUnitTestSuite;
- protected ExeConfigRunnerBase(ArgumentsBase args, IHostEnvironment env, string registrationName)
+ protected ExeConfigRunnerBase(OptionsBase args, IHostEnvironment env, string registrationName)
{
Contracts.AssertValue(env);
Host = env.Register(registrationName);
@@ -82,7 +82,7 @@ protected virtual void ProcessFullExePath(string exe)
Exe = GetFullExePath(exe);
if (!File.Exists(Exe) && !File.Exists(Exe + ".exe"))
- throw Host.ExceptUserArg(nameof(ArgumentsBase.Exe), "Executable {0} not found", Exe);
+ throw Host.ExceptUserArg(nameof(OptionsBase.Exe), "Executable {0} not found", Exe);
}
protected virtual string GetFullExePath(string exe)
@@ -180,7 +180,7 @@ protected string GetFilePath(int i, string kind)
public sealed class LocalExeConfigRunner : ExeConfigRunnerBase
{
- public sealed class Arguments : ArgumentsBase
+ public sealed class Options : OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of threads to use for the sweep (default auto determined by the number of cores)", ShortName = "t")]
public int? NumThreads;
@@ -188,13 +188,13 @@ public sealed class Arguments : ArgumentsBase
private readonly ParallelOptions _parallelOptions;
- public LocalExeConfigRunner(IHostEnvironment env, Arguments args)
- : base(args, env, "LocalExeSweepEvaluator")
+ public LocalExeConfigRunner(IHostEnvironment env, Options options)
+ : base(options, env, "LocalExeSweepEvaluator")
{
- Contracts.CheckParam(args.NumThreads == null || args.NumThreads.Value > 0, nameof(args.NumThreads), "Cannot be 0 or negative");
- _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = args.NumThreads ?? -1 };
- Contracts.AssertNonEmpty(args.OutputFolderName);
- ProcessFullExePath(args.Exe);
+ Contracts.CheckParam(options.NumThreads == null || options.NumThreads.Value > 0, nameof(options.NumThreads), "Cannot be 0 or negative");
+ _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = options.NumThreads ?? -1 };
+ Contracts.AssertNonEmpty(options.OutputFolderName);
+ ProcessFullExePath(options.Exe);
}
protected override IEnumerable RunConfigsCore(ParameterSet[] sweeps, IChannel ch, int min)
diff --git a/src/Microsoft.ML.Sweeper/SweepCommand.cs b/src/Microsoft.ML.Sweeper/SweepCommand.cs
index c1959272fb..2e27338040 100644
--- a/src/Microsoft.ML.Sweeper/SweepCommand.cs
+++ b/src/Microsoft.ML.Sweeper/SweepCommand.cs
@@ -24,7 +24,7 @@ public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "Config runner", ShortName = "run,ev,evaluator", SignatureType = typeof(SignatureConfigRunner))]
public IComponentFactory Runner = ComponentFactoryUtils.CreateFromFunction(
- env => new LocalExeConfigRunner(env, new LocalExeConfigRunner.Arguments()));
+ env => new LocalExeConfigRunner(env, new LocalExeConfigRunner.Options()));
[Argument(ArgumentType.Multiple, HelpText = "Sweeper", ShortName = "s", SignatureType = typeof(SignatureSweeper))]
public IComponentFactory Sweeper;
diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
index a55f8d4647..5b5f298009 100644
--- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
+++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
@@ -10,14 +10,14 @@
using Microsoft.ML.Sweeper;
using ResultProcessor = Microsoft.ML.ResultProcessor;
-[assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Arguments), typeof(SignatureSweepResultEvaluator),
+[assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Options), typeof(SignatureSweepResultEvaluator),
"TLC Sweep Result Evaluator", "TlcEvaluator", "Tlc")]
namespace Microsoft.ML.Sweeper
{
public class InternalSweepResultEvaluator : ISweepResultEvaluator
{
- public class Arguments
+ public sealed class Options
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The sweeper used to get the initial results.", ShortName = "m")]
public string Metric = "AUC";
@@ -28,7 +28,7 @@ public class Arguments
private readonly IHost _host;
- public InternalSweepResultEvaluator(IHostEnvironment env, Arguments args)
+ public InternalSweepResultEvaluator(IHostEnvironment env, Options args)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register("InternalSweepResultEvaluator");
diff --git a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs
index d112c0659e..21555277af 100644
--- a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs
+++ b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs
@@ -11,7 +11,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Sweeper;
-[assembly: LoadableClass(typeof(SynthConfigRunner), typeof(SynthConfigRunner.Arguments), typeof(SignatureConfigRunner),
+[assembly: LoadableClass(typeof(SynthConfigRunner), typeof(SynthConfigRunner.Options), typeof(SignatureConfigRunner),
"", "Synth")]
namespace Microsoft.ML.Sweeper
@@ -22,7 +22,7 @@ namespace Microsoft.ML.Sweeper
///
public sealed class SynthConfigRunner : ExeConfigRunnerBase
{
- public sealed class Arguments : ArgumentsBase
+ public sealed class Options : OptionsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of threads to use for the sweep (default auto determined by the number of cores)", ShortName = "t")]
public int? NumThreads;
@@ -30,13 +30,13 @@ public sealed class Arguments : ArgumentsBase
private readonly ParallelOptions _parallelOptions;
- public SynthConfigRunner(IHostEnvironment env, Arguments args)
- : base(args, env, "SynthSweepEvaluator")
+ public SynthConfigRunner(IHostEnvironment env, Options options)
+ : base(options, env, "SynthSweepEvaluator")
{
- Host.CheckUserArg(args.NumThreads == null || args.NumThreads.Value > 0, nameof(args.NumThreads), "Must be positive");
- _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = args.NumThreads ?? -1 };
- Host.AssertNonEmpty(args.OutputFolderName);
- ProcessFullExePath(args.Exe);
+ Host.CheckUserArg(options.NumThreads == null || options.NumThreads.Value > 0, nameof(options.NumThreads), "Must be positive");
+ _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = options.NumThreads ?? -1 };
+ Host.AssertNonEmpty(options.OutputFolderName);
+ ProcessFullExePath(options.Exe);
}
protected override IEnumerable RunConfigsCore(ParameterSet[] sweeps, IChannel ch, int min)
diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
index 40915c484c..d64150c78d 100644
--- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
+++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
@@ -10,10 +10,10 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms;
-[assembly: LoadableClass(typeof(GaussianFourierSampler), typeof(GaussianFourierSampler.Arguments), typeof(SignatureFourierDistributionSampler),
+[assembly: LoadableClass(typeof(GaussianFourierSampler), typeof(GaussianFourierSampler.Options), typeof(SignatureFourierDistributionSampler),
"Gaussian Kernel", GaussianFourierSampler.LoadName, "Gaussian")]
-[assembly: LoadableClass(typeof(LaplacianFourierSampler), typeof(LaplacianFourierSampler.Arguments), typeof(SignatureFourierDistributionSampler),
+[assembly: LoadableClass(typeof(LaplacianFourierSampler), typeof(LaplacianFourierSampler.Options), typeof(SignatureFourierDistributionSampler),
"Laplacian Kernel", LaplacianFourierSampler.RegistrationName, "Laplacian")]
// This is for deserialization from a binary model file.
@@ -46,7 +46,7 @@ public sealed class GaussianFourierSampler : IFourierDistributionSampler
{
private readonly IHost _host;
- public class Arguments : IFourierDistributionSamplerFactory
+ public sealed class Options : IFourierDistributionSamplerFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "gamma in the kernel definition: exp(-gamma*||x-y||^2 / r^2). r is an estimate of the average intra-example distance", ShortName = "g")]
public float Gamma = 1;
@@ -70,7 +70,7 @@ private static VersionInfo GetVersionInfo()
private readonly float _gamma;
- public GaussianFourierSampler(IHostEnvironment env, Arguments args, float avgDist)
+ public GaussianFourierSampler(IHostEnvironment env, Options args, float avgDist)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(LoadName);
@@ -125,7 +125,7 @@ public float Next(Random rand)
public sealed class LaplacianFourierSampler : IFourierDistributionSampler
{
- public class Arguments : IFourierDistributionSamplerFactory
+ public sealed class Options : IFourierDistributionSamplerFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "a in the term exp(-a|x| / r). r is an estimate of the average intra-example L1 distance")]
public float A = 1;
@@ -150,7 +150,7 @@ private static VersionInfo GetVersionInfo()
private readonly IHost _host;
private readonly float _a;
- public LaplacianFourierSampler(IHostEnvironment env, Arguments args, float avgDist)
+ public LaplacianFourierSampler(IHostEnvironment env, Options args, float avgDist)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
index c1fb8b58d9..2663b37257 100644
--- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs
@@ -132,30 +132,30 @@ private static IDataView Create(IHostEnvironment env, IDataView input, string ou
}
/// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register("Categorical");
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
- h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
+ h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
var replaceCols = new List();
var naIndicatorCols = new List();
var naConvCols = new List();
var concatCols = new List();
var dropCols = new List();
- var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Columns.Length, "IsMissing");
- var tmpReplaceColNames = input.Schema.GetTempColumnNames(args.Columns.Length, "Replace");
- for (int i = 0; i < args.Columns.Length; i++)
+ var tmpIsMissingColNames = input.Schema.GetTempColumnNames(options.Columns.Length, "IsMissing");
+ var tmpReplaceColNames = input.Schema.GetTempColumnNames(options.Columns.Length, "Replace");
+ for (int i = 0; i < options.Columns.Length; i++)
{
- var column = args.Columns[i];
+ var column = options.Columns[i];
- var addInd = column.ConcatIndicator ?? args.Concat;
+ var addInd = column.ConcatIndicator ?? options.Concat;
if (!addInd)
{
replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(column.Name, column.Source,
- (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));
+ (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot));
continue;
}
@@ -190,7 +190,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options args, IDataV
// Add the NAReplaceTransform column.
replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(tmpReplacementColName, column.Source,
- (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));
+ (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot));
// Add the ConcatTransform column.
if (replaceType is VectorType)
diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
index a07b854a2f..96be9fab0e 100644
--- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
+++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs
@@ -44,7 +44,7 @@ internal sealed class Options
public int NewDim = RandomFourierFeaturizingEstimator.Defaults.NewDim;
[Argument(ArgumentType.Multiple, HelpText = "Which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))]
- public IComponentFactory MatrixGenerator = new GaussianFourierSampler.Arguments();
+ public IComponentFactory MatrixGenerator = new GaussianFourierSampler.Options();
[Argument(ArgumentType.AtMostOnce, HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)")]
public bool UseSin = RandomFourierFeaturizingEstimator.Defaults.UseSin;
@@ -666,7 +666,7 @@ public ColumnInfo(string name, int newDim, bool useSin, string inputColumnName =
Contracts.CheckUserArg(newDim > 0, nameof(newDim), "must be positive.");
InputColumnName = inputColumnName ?? name;
Name = name;
- Generator = generator ?? new GaussianFourierSampler.Arguments();
+ Generator = generator ?? new GaussianFourierSampler.Options();
NewDim = newDim;
UseSin = useSin;
Seed = seed;
diff --git a/test/Microsoft.ML.Benchmarks/RffTransform.cs b/test/Microsoft.ML.Benchmarks/RffTransform.cs
index b94eb7930e..815f8a3f33 100644
--- a/test/Microsoft.ML.Benchmarks/RffTransform.cs
+++ b/test/Microsoft.ML.Benchmarks/RffTransform.cs
@@ -30,7 +30,7 @@ public void SetupTrainingSpeedTests()
public void CV_Multiclass_Digits_RffTransform_OVAAveragedPerceptron()
{
var mlContext = new MLContext();
- var reader = mlContext.Data.CreateTextLoader(new TextLoader.Arguments
+ var reader = mlContext.Data.CreateTextLoader(new TextLoader.Options
{
Columns = new[]
{
diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
index c08118bee6..7e97160e96 100644
--- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
@@ -77,7 +77,7 @@ private TransformerChain>(scoreColumn.Value.Index);
- var c = new MultiAverage(Env, new MultiAverage.Arguments()).GetCombiner();
+ var c = new MultiAverage(Env, new MultiAverage.Options()).GetCombiner();
VBuffer score = default(VBuffer);
VBuffer[] score0 = new VBuffer[5];
VBuffer scoreSaved = default(VBuffer);
@@ -2513,7 +2513,7 @@ public void TestInputBuilderComponentFactories()
expected = FixWhitespace(expected);
Assert.Equal(expected, json);
- options.LossFunction = new HingeLoss.Arguments();
+ options.LossFunction = new HingeLoss.Options();
result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap);
json = FixWhitespace(result.ToString(Formatting.Indented));
@@ -2529,7 +2529,7 @@ public void TestInputBuilderComponentFactories()
expected = FixWhitespace(expected);
Assert.Equal(expected, json);
- options.LossFunction = new HingeLoss.Arguments() { Margin = 2 };
+ options.LossFunction = new HingeLoss.Options() { Margin = 2 };
result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap);
json = FixWhitespace(result.ToString(Formatting.Indented));
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
index faa1ba5f26..fff6e594f2 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
@@ -84,7 +84,7 @@ public void LogEventProcessesMessages()
// create a dummy text reader to trigger log messages
env.Data.CreateTextLoader(
- new TextLoader.Arguments {Columns = new[] {new TextLoader.Column("TestColumn", null, 0)}});
+ new TextLoader.Options {Columns = new[] {new TextLoader.Column("TestColumn", null, 0)}});
Assert.True(messages.Count > 0);
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs
index ef03fdce16..8a10eb6051 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs
@@ -64,8 +64,8 @@ public void LossHinge()
[Fact]
public void LossExponential()
{
- ExpLoss.Arguments args = new ExpLoss.Arguments();
- ExpLoss loss = new ExpLoss(args);
+ ExpLoss.Options options = new ExpLoss.Options();
+ ExpLoss loss = new ExpLoss(options);
TestHelper(loss, 1, 3, Math.Exp(-3), Math.Exp(-3));
TestHelper(loss, 0, 3, Math.Exp(3), -Math.Exp(3));
TestHelper(loss, 0, -3, Math.Exp(-3), -Math.Exp(-3));
diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs
index a1e675d248..a39bd14884 100644
--- a/test/Microsoft.ML.Functional.Tests/Validation.cs
+++ b/test/Microsoft.ML.Functional.Tests/Validation.cs
@@ -82,7 +82,7 @@ public void TrainWithValidationSet()
var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
NumTrees = 2,
EarlyStoppingMetrics = 2,
- EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
+ EarlyStoppingRule = new GLEarlyStoppingCriterion.Options()
})
.Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData);
diff --git a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
index 69ddb618ac..234ff0a7d5 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
@@ -522,7 +522,7 @@ public void TestGamRegressionIni()
{
var mlContext = new MLContext(seed: 0);
var idv = mlContext.Data.CreateTextLoader(
- new TextLoader.Arguments()
+ new TextLoader.Options()
{
HasHeader = false,
Columns = new[]
@@ -561,7 +561,7 @@ public void TestGamBinaryClassificationIni()
{
var mlContext = new MLContext(seed: 0);
var idv = mlContext.Data.CreateTextLoader(
- new TextLoader.Arguments()
+ new TextLoader.Options()
{
HasHeader = false,
Columns = new[]
diff --git a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
index 077ed5104c..9f88d678d7 100644
--- a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
@@ -15,7 +15,7 @@ public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep()
var env = new MLContext(42);
var sweeper = new UniformRandomSweeper(env,
- new SweeperBase.ArgumentsBase(),
+ new SweeperBase.OptionsBase(),
new[] { valueGenerator });
var results = sweeper.ProposeSweeps(3);
@@ -32,7 +32,7 @@ public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
var env = new MLContext(42);
var sweeper = new RandomGridSweeper(env,
- new RandomGridSweeper.Arguments(),
+ new RandomGridSweeper.Options(),
new[] { valueGenerator });
var results = sweeper.ProposeSweeps(3);
diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
index 11bd4ae0db..07dd0db895 100644
--- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
@@ -90,7 +90,7 @@ public void TestDiscreteValueSweep(double normalizedValue, string expected)
public void TestRandomSweeper()
{
var env = new MLContext(42);
- var args = new SweeperBase.ArgumentsBase()
+ var args = new SweeperBase.OptionsBase()
{
SweptParameters = new[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -131,7 +131,7 @@ public void TestSimpleSweeperAsync()
var random = new Random(42);
var env = new MLContext(42);
const int sweeps = 100;
- var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase
+ var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.OptionsBase
{
SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -154,7 +154,7 @@ public void TestSimpleSweeperAsync()
CheckAsyncSweeperResult(paramSets);
// Test consumption without ever calling Update.
- var gridArgs = new RandomGridSweeper.Arguments();
+ var gridArgs = new RandomGridSweeper.Options();
gridArgs.SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
@@ -178,13 +178,13 @@ public void TestDeterministicSweeperAsyncCancellation()
{
var random = new Random(42);
var env = new MLContext(42);
- var args = new DeterministicSweeperAsync.Arguments();
+ var args = new DeterministicSweeperAsync.Options();
args.BatchSize = 5;
args.Relaxation = 1;
args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
environ => new KdoSweeper(environ,
- new KdoSweeper.Arguments()
+ new KdoSweeper.Options()
{
SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -228,13 +228,13 @@ public void TestDeterministicSweeperAsync()
{
var random = new Random(42);
var env = new MLContext(42);
- var args = new DeterministicSweeperAsync.Arguments();
+ var args = new DeterministicSweeperAsync.Options();
args.BatchSize = 5;
args.Relaxation = args.BatchSize - 1;
args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
environ => new SmacSweeper(environ,
- new SmacSweeper.Arguments()
+ new SmacSweeper.Options()
{
SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -300,13 +300,13 @@ public void TestDeterministicSweeperAsyncParallel()
const int batchSize = 5;
const int sweeps = 20;
var paramSets = new List();
- var args = new DeterministicSweeperAsync.Arguments();
+ var args = new DeterministicSweeperAsync.Options();
args.BatchSize = batchSize;
args.Relaxation = batchSize - 2;
args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
environ => new SmacSweeper(environ,
- new SmacSweeper.Arguments()
+ new SmacSweeper.Options()
{
SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -352,7 +352,7 @@ public async Task TestNelderMeadSweeperAsync()
const int batchSize = 5;
const int sweeps = 40;
var paramSets = new List();
- var args = new DeterministicSweeperAsync.Arguments();
+ var args = new DeterministicSweeperAsync.Options();
args.BatchSize = batchSize;
args.Relaxation = 0;
@@ -366,12 +366,12 @@ public async Task TestNelderMeadSweeperAsync()
innerEnviron => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
};
- var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments()
+ var nelderMeadSweeperArgs = new NelderMeadSweeper.Options()
{
SweptParameters = param,
FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
(firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
- new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }))
+ new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param }))
};
return new NelderMeadSweeper(environ, nelderMeadSweeperArgs);
@@ -427,7 +427,7 @@ private void CheckAsyncSweeperResult(List paramSets)
public void TestRandomGridSweeper()
{
var env = new MLContext(42);
- var args = new RandomGridSweeper.Arguments()
+ var args = new RandomGridSweeper.Options()
{
SweptParameters = new[] {
ComponentFactoryUtils.CreateFromFunction(
@@ -544,13 +544,13 @@ public void TestNelderMeadSweeper()
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
};
- var args = new NelderMeadSweeper.Arguments()
+ var args = new NelderMeadSweeper.Options()
{
SweptParameters = param,
FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
(environ, firstBatchArgs) =>
{
- return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param });
+ return new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param });
}
)
};
@@ -600,7 +600,7 @@ public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
};
- var args = new NelderMeadSweeper.Arguments();
+ var args = new NelderMeadSweeper.Options();
args.SweptParameters = param;
var sweeper = new NelderMeadSweeper(env, args);
var sweeps = sweeper.ProposeSweeps(5, new List());
@@ -642,7 +642,7 @@ public void TestSmacSweeper()
var random = new Random(42);
var env = new MLContext(42);
const int maxInitSweeps = 5;
- var args = new SmacSweeper.Arguments()
+ var args = new SmacSweeper.Options()
{
NumberInitialPopulation = 20,
SweptParameters = new IComponentFactory[] {
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
index 8528a51c1b..10caa9c6de 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
@@ -819,7 +819,7 @@ public void SavePipeWithKey()
pipe =>
{
- var argsText = new TextLoader.Arguments();
+ var argsText = new TextLoader.Options();
bool tmp = CmdParser.ParseArguments(Env,
" header=+" +
" col=Label:TX:0" +
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index 5da719472a..b1a4ad34c9 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -403,7 +403,7 @@ protected bool SaveLoadText(IDataView view, IHostEnvironment env,
view = new ChooseColumnsByIndexTransform(env, chooseargs, view);
}
- var args = new TextLoader.Arguments();
+ var args = new TextLoader.Options();
if (!CmdParser.ParseArguments(Env, argsLoader, args))
{
Fail("Couldn't parse the args '{0}' in '{1}'", argsLoader, pathData);
diff --git a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs
index 251b3cc611..f9aba9550b 100644
--- a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs
+++ b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs
@@ -32,7 +32,7 @@ public void RandomizedPcaTrainerBaselineTest()
var mlContext = new MLContext(seed: 1, conc: 1);
string featureColumn = "NumericFeatures";
- var reader = new TextLoader(Env, new TextLoader.Arguments()
+ var reader = new TextLoader(Env, new TextLoader.Options()
{
HasHeader = true,
Separator = "\t",
diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs
index b2f1c26155..2bf7129e61 100644
--- a/test/Microsoft.ML.Tests/ImagesTests.cs
+++ b/test/Microsoft.ML.Tests/ImagesTests.cs
@@ -29,7 +29,7 @@ public void TestEstimatorChain()
var env = new MLContext();
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -37,7 +37,7 @@ public void TestEstimatorChain()
new TextLoader.Column("Name", DataKind.TX, 1),
}
}, new MultiFileSource(dataFile));
- var invalidData = TextLoader.Create(env, new TextLoader.Arguments()
+ var invalidData = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -60,7 +60,7 @@ public void TestEstimatorSaveLoad()
IHostEnvironment env = new MLContext();
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -99,7 +99,7 @@ public void TestSaveImages()
var env = new MLContext();
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -138,7 +138,7 @@ public void TestGreyscaleTransformImages()
var imageWidth = 100;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -189,7 +189,7 @@ public void TestBackAndForthConversionWithAlphaInterleave()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -256,7 +256,7 @@ public void TestBackAndForthConversionWithoutAlphaInterleave()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -323,7 +323,7 @@ public void TestBackAndForthConversionWithAlphaNoInterleave()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -390,7 +390,7 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -457,7 +457,7 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -523,7 +523,7 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -589,7 +589,7 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -655,7 +655,7 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset()
const int imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
@@ -718,7 +718,7 @@ public void ImageResizerTransformResizingModeFill()
var env = new MLContext();
var dataFile = GetDataPath("images/fillmode.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ var data = TextLoader.Create(env, new TextLoader.Options()
{
Columns = new[]
{
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
index 817b8c5f3a..3a6c643bef 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
@@ -171,7 +171,7 @@ public void TrainAveragedPerceptronWithCache()
{
var mlContext = new MLContext(0);
var dataFile = GetDataPath("breast-cancer.txt");
- var loader = TextLoader.Create(mlContext, new TextLoader.Arguments(), new MultiFileSource(dataFile));
+ var loader = TextLoader.Create(mlContext, new TextLoader.Options(), new MultiFileSource(dataFile));
var globalCounter = 0;
var xf = LambdaTransform.CreateFilter