diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs index 8a24b86aef..25772bb82d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs @@ -28,7 +28,7 @@ public static void Example() var mlContext = new MLContext(); // Create a text loader. - var reader = mlContext.Data.CreateTextLoader(new TextLoader.Arguments() + var reader = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Separators = new[] { '\t' }, HasHeader = true, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs index 87152b211b..b168575290 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ConvertToGrayScale.cs @@ -23,7 +23,7 @@ public static void Example() // hotdog.jpg hotdog // tomato.jpg tomato - var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments() + var data = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Columns = new[] { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs index 4f88cdf531..eb7e164004 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ExtractPixels.cs @@ -24,7 +24,7 @@ public static void Example() // hotdog.jpg hotdog // tomato.jpg tomato - var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments() + var data = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Columns = new[] { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs index 5979c4bb3b..541f564283 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/LoadImages.cs @@ -23,7 +23,7 @@ public static void Example() // hotdog.jpg hotdog // tomato.jpg tomato - var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments() + var data = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Columns = new[] { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs index 4e59c20a83..03ada5304e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageAnalytics/ResizeImages.cs @@ -23,7 +23,7 @@ public static void Example() // hotdog.jpg hotdog // tomato.jpg tomato - var data = mlContext.Data.CreateTextLoader(new TextLoader.Arguments() + var data = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Columns = new[] { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs index aaf8d02e50..1a6aacfe33 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs @@ -31,7 +31,7 @@ public static void Example() // 14. Column: native-country (text/categorical) // 15. Column: Column [Label]: IsOver50K (boolean) - var reader = ml.Data.CreateTextLoader(new TextLoader.Arguments + var reader = ml.Data.CreateTextLoader(new TextLoader.Options { Separators = new[] { ',' }, HasHeader = true, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs index ee568bff92..830b5981cc 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs @@ -24,7 +24,7 @@ public static void Example() // Define the trainer options. var options = new AveragedPerceptronTrainer.Options() { - LossFunction = new SmoothedHingeLoss.Arguments(), + LossFunction = new SmoothedHingeLoss.Options(), LearningRate = 0.1f, DoLazyUpdates = false, RecencyGain = 0.1f, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs index c660b603c4..3a8a17952b 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs @@ -22,7 +22,7 @@ public static void Example() // The data is tab separated with all numeric columns. // The first column being the label and rest are numeric features // Here only seven numeric columns are used as features - var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Arguments + var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Options { Separators = new[] { '\t' }, HasHeader = true, diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs index 1d72e487cd..519a9ef683 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs @@ -23,7 +23,7 @@ public static void Example() // The data is tab separated with all numeric columns. // The first column being the label and rest are numeric features // Here only seven numeric columns are used as features - var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Arguments + var dataView = mlContext.Data.ReadFromTextFile(dataFile, new TextLoader.Options { Separators = new[] { '\t' }, HasHeader = true, diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 8997cfd29b..8e8a390f1e 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -364,11 +364,11 @@ protected IDataLoader CreateRawLoader( var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase); var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase); - return isText ? TextLoader.Create(Host, new TextLoader.Arguments(), fileSource) : + return isText ? TextLoader.Create(Host, new TextLoader.Options(), fileSource) : isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) : isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) : defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) : - TextLoader.Create(Host, new TextLoader.Arguments(), fileSource); + TextLoader.Create(Host, new TextLoader.Options(), fileSource); } else { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 17ed87c05d..b1c0a40d88 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Model; using Float = System.Single; -[assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), typeof(TextLoader.Arguments), typeof(SignatureDataLoader), +[assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), typeof(TextLoader.Options), typeof(SignatureDataLoader), "Text Loader", "TextLoader", "Text", DocName = "loader/TextLoader.md")] [assembly: LoadableClass(TextLoader.Summary, typeof(IDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader), @@ -379,7 +379,7 @@ public bool IsValid() } } - public sealed class Arguments : ArgumentsCore + public sealed class Options : ArgumentsCore { [Argument(ArgumentType.AtMostOnce, HelpText = "Use separate parsing threads?", ShortName = "threads", Hide = true)] public bool UseThreads = true; @@ -936,7 +936,7 @@ private static VersionInfo GetVersionInfo() /// bumping the version number. /// [Flags] - private enum Options : uint + private enum OptionFlags : uint { TrimWhitespace = 0x01, HasHeader = 0x02, @@ -950,7 +950,7 @@ private enum Options : uint private const int SrcLim = int.MaxValue; private readonly bool _useThreads; - private readonly Options _flags; + private readonly OptionFlags _flags; private readonly long _maxRows; // Input size is zero for unknown - determined by the data (including sparse rows). private readonly int _inputSize; @@ -961,7 +961,7 @@ private enum Options : uint private bool HasHeader { - get { return (_flags & Options.HasHeader) != 0; } + get { return (_flags & OptionFlags.HasHeader) != 0; } } private readonly IHost _host; @@ -980,10 +980,10 @@ public TextLoader(IHostEnvironment env, Column[] columns, bool hasHeader = false { } - private static Arguments MakeArgs(Column[] columns, bool hasHeader, char[] separatorChars) + private static Options MakeArgs(Column[] columns, bool hasHeader, char[] separatorChars) { Contracts.AssertValue(separatorChars); - var result = new Arguments { Columns = columns, HasHeader = hasHeader, Separators = separatorChars}; + var result = new Options { Columns = columns, HasHeader = hasHeader, Separators = separatorChars}; return result; } @@ -991,27 +991,27 @@ private static Arguments MakeArgs(Column[] columns, bool hasHeader, char[] separ /// Loads a text file into an . Supports basic mapping from input columns to IDataView columns. /// /// The environment to use. - /// Defines the settings of the load operation. + /// Defines the settings of the load operation. /// Allows to expose items that can be used for reading. - public TextLoader(IHostEnvironment env, Arguments args = null, IMultiStreamSource dataSample = null) + public TextLoader(IHostEnvironment env, Options options = null, IMultiStreamSource dataSample = null) { - args = args ?? new Arguments(); + options = options ?? new Options(); Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); - _host.CheckValue(args, nameof(args)); + _host.CheckValue(options, nameof(options)); _host.CheckValueOrNull(dataSample); if (dataSample == null) dataSample = new MultiFileSource(null); IMultiStreamSource headerFile = null; - if (!string.IsNullOrWhiteSpace(args.HeaderFile)) - headerFile = new MultiFileSource(args.HeaderFile); + if (!string.IsNullOrWhiteSpace(options.HeaderFile)) + headerFile = new MultiFileSource(options.HeaderFile); - var cols = args.Columns; + var cols = options.Columns; bool error; - if (Utils.Size(cols) == 0 && !TryParseSchema(_host, headerFile ?? dataSample, ref args, out cols, out error)) + if (Utils.Size(cols) == 0 && !TryParseSchema(_host, headerFile ?? dataSample, ref options, out cols, out error)) { if (error) throw _host.Except("TextLoader options embedded in the file are invalid"); @@ -1026,43 +1026,43 @@ public TextLoader(IHostEnvironment env, Arguments args = null, IMultiStreamSourc } _host.Assert(Utils.Size(cols) > 0); - _useThreads = args.UseThreads; + _useThreads = options.UseThreads; - if (args.TrimWhitespace) - _flags |= Options.TrimWhitespace; - if (headerFile == null && args.HasHeader) - _flags |= Options.HasHeader; - if (args.AllowQuoting) - _flags |= Options.AllowQuoting; - if (args.AllowSparse) - _flags |= Options.AllowSparse; + if (options.TrimWhitespace) + _flags |= OptionFlags.TrimWhitespace; + if (headerFile == null && options.HasHeader) + _flags |= OptionFlags.HasHeader; + if (options.AllowQuoting) + _flags |= OptionFlags.AllowQuoting; + if (options.AllowSparse) + _flags |= OptionFlags.AllowSparse; // REVIEW: This should be persisted (if it should be maintained). - _maxRows = args.MaxRows ?? long.MaxValue; - _host.CheckUserArg(_maxRows >= 0, nameof(args.MaxRows)); + _maxRows = options.MaxRows ?? long.MaxValue; + _host.CheckUserArg(_maxRows >= 0, nameof(options.MaxRows)); // Note that _maxDim == 0 means sparsity is illegal. - _inputSize = args.InputSize ?? 0; + _inputSize = options.InputSize ?? 0; _host.Check(_inputSize >= 0, "inputSize"); if (_inputSize >= SrcLim) _inputSize = SrcLim - 1; - _host.CheckNonEmpty(args.Separator, nameof(args.Separator), "Must specify a separator"); + _host.CheckNonEmpty(options.Separator, nameof(options.Separator), "Must specify a separator"); //Default arg.Separator is tab and default args.Separators is also a '\t'. //At a time only one default can be different and whichever is different that will //be used. - if (args.Separators.Length > 1 || args.Separators[0] != '\t') + if (options.Separators.Length > 1 || options.Separators[0] != '\t') { var separators = new HashSet(); - foreach (char c in args.Separators) + foreach (char c in options.Separators) separators.Add(NormalizeSeparator(c.ToString())); _separators = separators.ToArray(); } else { - string sep = args.Separator.ToLowerInvariant(); + string sep = options.Separator.ToLowerInvariant(); if (sep == ",") _separators = new char[] { ',' }; else @@ -1103,7 +1103,7 @@ private char NormalizeSeparator(string sep) return ','; case "colon": case ":": - _host.CheckUserArg((_flags & Options.AllowSparse) == 0, nameof(Arguments.Separator), + _host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator), "When the separator is colon, turn off allowSparse"); return ':'; case "semicolon": @@ -1115,7 +1115,7 @@ private char NormalizeSeparator(string sep) default: char ch = sep[0]; if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"') - throw _host.ExceptUserArg(nameof(Arguments.Separator), "Illegal separator: '{0}'", sep); + throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep); return sep[0]; } } @@ -1134,7 +1134,7 @@ private sealed class LoaderHolder // If so, update args and set cols to the combined set of columns. // If not, set error to true if there was an error condition. private static bool TryParseSchema(IHost host, IMultiStreamSource files, - ref Arguments args, out Column[] cols, out bool error) + ref Options options, out Column[] cols, out bool error) { host.AssertValue(host); host.AssertValue(files); @@ -1144,7 +1144,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, // Verify that the current schema-defining arguments are default. // Get settings just for core arguments, not everything. - string tmp = CmdParser.GetSettings(host, args, new ArgumentsCore()); + string tmp = CmdParser.GetSettings(host, options, new ArgumentsCore()); // Try to get the schema information from the file. string str = Cursor.GetEmbeddedArgs(files); @@ -1176,12 +1176,12 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, // Make sure the loader binds to us. var info = host.ComponentCatalog.GetLoadableClassInfo(loader.Name); - if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Arguments)) + if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Options)) goto LDone; - var argsNew = new Arguments(); + var argsNew = new Options(); // Copy the non-core arguments to the new args (we already know that all the core arguments are default). - var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, args, new Arguments()), argsNew); + var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, options, new Options()), argsNew); ch.Assert(parsed); // Copy the core arguments to the new args. if (!CmdParser.ParseArguments(host, loader.GetSettingsString(), argsNew, typeof(ArgumentsCore), msg => ch.Error(msg))) @@ -1192,7 +1192,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, goto LDone; error = false; - args = argsNew; + options = argsNew; LDone: return !error; @@ -1202,16 +1202,16 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, /// /// Checks whether the source contains the valid TextLoader.Arguments depiction. /// - public static bool FileContainsValidSchema(IHostEnvironment env, IMultiStreamSource files, out Arguments args) + public static bool FileContainsValidSchema(IHostEnvironment env, IMultiStreamSource files, out Options options) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); h.CheckValue(files, nameof(files)); - args = new Arguments(); + options = new Options(); Column[] cols; bool error; - bool found = TryParseSchema(h, files, ref args, out cols, out error); - return found && !error && args.IsValid(); + bool found = TryParseSchema(h, files, ref options, out cols, out error); + return found && !error && options.IsValid(); } private TextLoader(IHost host, ModelLoadContext ctx) @@ -1236,8 +1236,8 @@ private TextLoader(IHost host, ModelLoadContext ctx) host.CheckDecode(cbFloat == sizeof(Float)); _maxRows = ctx.Reader.ReadInt64(); host.CheckDecode(_maxRows > 0); - _flags = (Options)ctx.Reader.ReadUInt32(); - host.CheckDecode((_flags & ~Options.All) == 0); + _flags = (OptionFlags)ctx.Reader.ReadUInt32(); + host.CheckDecode((_flags & ~OptionFlags.All) == 0); _inputSize = ctx.Reader.ReadInt32(); host.CheckDecode(0 <= _inputSize && _inputSize < SrcLim); @@ -1253,7 +1253,7 @@ private TextLoader(IHost host, ModelLoadContext ctx) } if (_separators.Contains(':')) - host.CheckDecode((_flags & Options.AllowSparse) == 0); + host.CheckDecode((_flags & OptionFlags.AllowSparse) == 0); _bindings = new Bindings(ctx, this); _parser = new Parser(this); @@ -1273,14 +1273,14 @@ internal static TextLoader Create(IHostEnvironment env, ModelLoadContext ctx) // These are legacy constructors needed for ComponentCatalog. internal static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files) => (IDataLoader)Create(env, ctx).Read(files); - internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStreamSource files) - => (IDataLoader)new TextLoader(env, args, files).Read(files); + internal static IDataLoader Create(IHostEnvironment env, Options options, IMultiStreamSource files) + => (IDataLoader)new TextLoader(env, options, files).Read(files); /// /// Convenience method to create a and use it to read a specified file. /// - internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource) - => new TextLoader(env, args, fileSource).Read(fileSource); + internal static IDataView ReadFile(IHostEnvironment env, Options options, IMultiStreamSource fileSource) + => new TextLoader(env, options, fileSource).Read(fileSource); void ICanSaveModel.Save(ModelSaveContext ctx) { @@ -1298,7 +1298,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx) // bindings ctx.Writer.Write(sizeof(Float)); ctx.Writer.Write(_maxRows); - _host.Assert((_flags & ~Options.All) == 0); + _host.Assert((_flags & ~OptionFlags.All) == 0); ctx.Writer.Write((uint)_flags); _host.Assert(0 <= _inputSize && _inputSize < SrcLim); ctx.Writer.Write(_inputSize); @@ -1367,7 +1367,7 @@ internal static TextLoader CreateTextReader(IHostEnvironment host, columns.Add(column); } - Arguments args = new Arguments + Options options = new Options { HasHeader = hasHeader, Separators = new[] { separator }, @@ -1377,7 +1377,7 @@ internal static TextLoader CreateTextReader(IHostEnvironment host, Columns = columns.ToArray() }; - return new TextLoader(host, args); + return new TextLoader(host, options); } private sealed class BoundLoader : IDataLoader diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index 7c7305c723..df6252c6bc 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -637,7 +637,7 @@ public void Clear() } private readonly char[] _separators; - private readonly Options _flags; + private readonly OptionFlags _flags; private readonly int _inputSize; private readonly ColInfo[] _infos; @@ -705,7 +705,7 @@ public static void GetInputSize(TextLoader parent, List> li { foreach (var line in lines) { - var text = (parent._flags & Options.TrimWhitespace) != 0 ? ReadOnlyMemoryUtils.TrimEndWhiteSpace(line) : line; + var text = (parent._flags & OptionFlags.TrimWhitespace) != 0 ? ReadOnlyMemoryUtils.TrimEndWhiteSpace(line) : line; if (text.IsEmpty) continue; @@ -828,7 +828,7 @@ public void ParseRow(RowSet rows, int irow, Helper helper, bool[] active, string var impl = (HelperImpl)helper; var lineSpan = text.AsMemory(); var span = lineSpan.Span; - if ((_flags & Options.TrimWhitespace) != 0) + if ((_flags & OptionFlags.TrimWhitespace) != 0) lineSpan = TrimEndWhiteSpace(lineSpan, span); try { @@ -883,7 +883,7 @@ private sealed class HelperImpl : Helper public readonly FieldSet Fields; - public HelperImpl(ParseStats stats, Options flags, char[] seps, int inputSize, int srcNeeded) + public HelperImpl(ParseStats stats, OptionFlags flags, char[] seps, int inputSize, int srcNeeded) { Contracts.AssertValue(stats); // inputSize == 0 means unknown. @@ -899,8 +899,8 @@ public HelperImpl(ParseStats stats, Options flags, char[] seps, int inputSize, i _sepContainsSpace = IsSep(' '); _inputSize = inputSize; _srcNeeded = srcNeeded; - _quoting = (flags & Options.AllowQuoting) != 0; - _sparse = (flags & Options.AllowSparse) != 0; + _quoting = (flags & OptionFlags.AllowQuoting) != 0; + _sparse = (flags & OptionFlags.AllowSparse) != 0; _sb = new StringBuilder(); _blank = ReadOnlyMemory.Empty; Fields = new FieldSet(); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs index f081220f41..7b5e684b46 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs @@ -30,12 +30,12 @@ public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog, /// Create a text loader . /// /// The catalog. - /// Defines the settings of the load operation. + /// Defines the settings of the load operation. /// The optional location of a data sample. The sample can be used to infer column names and number of slots in each column. public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog, - TextLoader.Arguments args, + TextLoader.Options options, IMultiStreamSource dataSample = null) - => new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample); + => new TextLoader(CatalogUtils.GetEnvironment(catalog), options, dataSample); /// /// Create a text loader by inferencing the dataset schema from a data model type. @@ -123,15 +123,15 @@ public static IDataView ReadFromTextFile(this DataOperationsCatalog cata /// /// The catalog. /// Specifies a file from which to read. - /// Defines the settings of the load operation. - public static IDataView ReadFromTextFile(this DataOperationsCatalog catalog, string path, TextLoader.Arguments args = null) + /// Defines the settings of the load operation. + public static IDataView ReadFromTextFile(this DataOperationsCatalog catalog, string path, TextLoader.Options options = null) { Contracts.CheckNonEmpty(path, nameof(path)); var env = catalog.GetEnvironment(); var source = new MultiFileSource(path); - return new TextLoader(env, args, source).Read(source); + return new TextLoader(env, options, source).Read(source); } /// diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs index 26a58b5b1a..6a53a5d2ce 100644 --- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs @@ -9,11 +9,11 @@ using Microsoft.ML.Internal.Internallearn; using Float = System.Single; -[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] -[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] -[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] -[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] -[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Arguments), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] +[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] +[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] +[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] +[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] +[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] [assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))] [assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))] @@ -44,14 +44,14 @@ public interface IEarlyStoppingCriterionFactory : IComponentFactory : IEarlyStoppingCriterion - where TArguments : EarlyStoppingCriterion.ArgumentsBase + public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion + where TOptions : EarlyStoppingCriterion.OptionsBase { - public abstract class ArgumentsBase { } + public abstract class OptionsBase { } private Float _bestScore; - protected readonly TArguments Args; + protected readonly TOptions EarlyStoppingCriterionOptions; protected readonly bool LowerIsBetter; protected Float BestScore { get { return _bestScore; } @@ -62,9 +62,9 @@ protected Float BestScore { } } - internal EarlyStoppingCriterion(TArguments args, bool lowerIsBetter) + internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter) { - Args = args; + EarlyStoppingCriterionOptions = options; LowerIsBetter = lowerIsBetter; _bestScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity; } @@ -86,10 +86,10 @@ protected bool CheckBestScore(Float score) } } - public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")] - public class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")] [TlcModule.Range(Min = 0.0f)] @@ -101,10 +101,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public TolerantEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) + public TolerantEarlyStoppingCriterion(Options args, bool lowerIsBetter) : base(args, lowerIsBetter) { - Contracts.CheckUserArg(Args.Threshold >= 0, nameof(args.Threshold), "Must be non-negative."); + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(args.Threshold), "Must be non-negative."); } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) @@ -114,9 +114,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore - BestScore > Args.Threshold); + return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold); else - return (BestScore - validationScore > Args.Threshold); + return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold); } } @@ -124,9 +124,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons." // Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009. - public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion + public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion { - public class Arguments : ArgumentsBase + public class Options : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -139,13 +139,13 @@ public class Arguments : ArgumentsBase protected Queue PastScores; - private protected MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) + private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter) : base(args, lowerIsBetter) { - Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); - Contracts.CheckUserArg(Args.WindowSize > 0, nameof(args.WindowSize), "Must be positive."); + Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive."); - PastScores = new Queue(Args.WindowSize); + PastScores = new Queue(EarlyStoppingCriterionOptions.WindowSize); } /// @@ -203,11 +203,11 @@ protected bool CheckRecentScores(Float score, int windowSize, out Float recentBe /// /// Loss of Generality (GL). /// - public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL", Desc = "Stop in case of loss of generality.")] - public class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -219,10 +219,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public GLEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) - : base(args, lowerIsBetter) + public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter) + : base(options, lowerIsBetter) { - Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) @@ -232,9 +232,9 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore > (1 + Args.Threshold) * BestScore); + return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore); else - return (validationScore < (1 - Args.Threshold) * BestScore); + return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore); } } @@ -245,7 +245,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")] - public new sealed class Arguments : MovingWindowEarlyStoppingCriterion.Arguments, IEarlyStoppingCriterionFactory + public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory { public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { @@ -253,7 +253,7 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public LPEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) + public LPEarlyStoppingCriterion(Options args, bool lowerIsBetter) : base(args, lowerIsBetter) { } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) @@ -265,12 +265,12 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out Float recentBest; Float recentAverage; - if (CheckRecentScores(trainingScore, Args.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (recentAverage <= (1 + Args.Threshold) * recentBest); + return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest); else - return (recentAverage >= (1 - Args.Threshold) * recentBest); + return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest); } return false; @@ -283,7 +283,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")] - public new sealed class Arguments : MovingWindowEarlyStoppingCriterion.Arguments, IEarlyStoppingCriterionFactory + public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory { public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { @@ -291,7 +291,7 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI } } - public PQEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) + public PQEarlyStoppingCriterion(Options args, bool lowerIsBetter) : base(args, lowerIsBetter) { } public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) @@ -303,12 +303,12 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out Float recentBest; Float recentAverage; - if (CheckRecentScores(trainingScore, Args.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (validationScore * recentBest >= (1 + Args.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); else - return (validationScore * recentBest <= (1 - Args.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); } return false; @@ -318,11 +318,11 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out /// /// Consecutive Loss in Generality (UP). /// - public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP", Desc = "Stops in case of consecutive loss in generality.")] - public sealed class Arguments : ArgumentsBase, IEarlyStoppingCriterionFactory + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] [TlcModule.Range(Inf = 0)] @@ -337,10 +337,10 @@ public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerI private int _count; private Float _prevScore; - public UPEarlyStoppingCriterion(Arguments args, bool lowerIsBetter) - : base(args, lowerIsBetter) + public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter) + : base(options, lowerIsBetter) { - Contracts.CheckUserArg(Args.WindowSize > 0, nameof(args.WindowSize), "Must be positive"); + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); _prevScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity; } @@ -354,7 +354,7 @@ public override bool CheckScore(Float validationScore, Float trainingScore, out _count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0; _prevScore = validationScore; - return (_count >= Args.WindowSize); + return (_count >= EarlyStoppingCriterionOptions.WindowSize); } } } diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 2321fcaaf5..864cceb128 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -88,14 +88,14 @@ public static LabelIndicatorTransform Create(IHostEnvironment env, } public static LabelIndicatorTransform Create(IHostEnvironment env, - Options args, IDataView input) + Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); IHost h = env.Register(LoaderSignature); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); return h.Apply("Loading Model", - ch => new LabelIndicatorTransform(h, args, input)); + ch => new LabelIndicatorTransform(h, options, input)); } private protected override void SaveModel(ModelSaveContext ctx) @@ -131,16 +131,16 @@ public LabelIndicatorTransform(IHostEnvironment env, { } - public LabelIndicatorTransform(IHostEnvironment env, Options args, IDataView input) - : base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Columns, + public LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input) + : base(env, LoadName, Contracts.CheckRef(options, nameof(options)).Columns, input, TestIsMulticlassLabel) { Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Columns)); + Host.Assert(Infos.Length == Utils.Size(options.Columns)); _classIndex = new int[Infos.Length]; for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) - _classIndex[iinfo] = args.Columns[iinfo].ClassIndex ?? args.ClassIndex; + _classIndex[iinfo] = options.Columns[iinfo].ClassIndex ?? options.ClassIndex; Metadata.Seal(); } diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 2431667c04..70b69e574b 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -107,21 +107,21 @@ public RowShufflingTransformer(IHostEnvironment env, /// /// Constructor corresponding to SignatureDataTransform. /// - public RowShufflingTransformer(IHostEnvironment env, Options args, IDataView input) + public RowShufflingTransformer(IHostEnvironment env, Options options, IDataView input) : base(env, RegistrationName, input) { - Host.CheckValue(args, nameof(args)); + Host.CheckValue(options, nameof(options)); - Host.CheckUserArg(args.PoolRows > 0, nameof(args.PoolRows), "pool size must be positive"); - _poolRows = args.PoolRows; - _poolOnly = args.PoolOnly; - _forceShuffle = args.ForceShuffle; - _forceShuffleSource = args.ForceShuffleSource ?? (!_poolOnly && _forceShuffle); + Host.CheckUserArg(options.PoolRows > 0, nameof(options.PoolRows), "pool size must be positive"); + _poolRows = options.PoolRows; + _poolOnly = options.PoolOnly; + _forceShuffle = options.ForceShuffle; + _forceShuffleSource = options.ForceShuffleSource ?? (!_poolOnly && _forceShuffle); Host.CheckUserArg(!(_poolOnly && _forceShuffleSource), - nameof(args.ForceShuffleSource), "Cannot set both poolOnly and forceShuffleSource"); + nameof(options.ForceShuffleSource), "Cannot set both poolOnly and forceShuffleSource"); if (_forceShuffle || _forceShuffleSource) - _forceShuffleSeed = args.ForceShuffleSeed ?? Host.Rand.NextSigned(); + _forceShuffleSeed = options.ForceShuffleSeed ?? Host.Rand.NextSigned(); _subsetInput = SelectCachableColumns(input, env); } diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 7e71763125..b79610e53e 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -600,7 +600,7 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat { keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0); valueColumn = new TextLoader.Column(valueColumnName, DataKind.TXT, 1); - var txtArgs = new TextLoader.Arguments() + var txtArgs = new TextLoader.Options() { Columns = new TextLoader.Column[] { @@ -627,7 +627,7 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat loader = TextLoader.Create( env, - new TextLoader.Arguments() + new TextLoader.Options() { Columns = new TextLoader.Column[] { diff --git a/src/Microsoft.ML.Data/Utils/LossFunctions.cs b/src/Microsoft.ML.Data/Utils/LossFunctions.cs index ed3ad17765..bf3eec4af7 100644 --- a/src/Microsoft.ML.Data/Utils/LossFunctions.cs +++ b/src/Microsoft.ML.Data/Utils/LossFunctions.cs @@ -12,13 +12,13 @@ [assembly: LoadableClass(LogLoss.Summary, typeof(LogLoss), null, typeof(SignatureClassificationLoss), "Log Loss", "LogLoss", "Logistic", "CrossEntropy")] -[assembly: LoadableClass(HingeLoss.Summary, typeof(HingeLoss), typeof(HingeLoss.Arguments), typeof(SignatureClassificationLoss), +[assembly: LoadableClass(HingeLoss.Summary, typeof(HingeLoss), typeof(HingeLoss.Options), typeof(SignatureClassificationLoss), "Hinge Loss", "HingeLoss", "Hinge")] -[assembly: LoadableClass(SmoothedHingeLoss.Summary, typeof(SmoothedHingeLoss), typeof(SmoothedHingeLoss.Arguments), typeof(SignatureClassificationLoss), +[assembly: LoadableClass(SmoothedHingeLoss.Summary, typeof(SmoothedHingeLoss), typeof(SmoothedHingeLoss.Options), typeof(SignatureClassificationLoss), "Smoothed Hinge Loss", "SmoothedHingeLoss", "SmoothedHinge")] -[assembly: LoadableClass(ExpLoss.Summary, typeof(ExpLoss), typeof(ExpLoss.Arguments), typeof(SignatureClassificationLoss), +[assembly: LoadableClass(ExpLoss.Summary, typeof(ExpLoss), typeof(ExpLoss.Options), typeof(SignatureClassificationLoss), "Exponential Loss", "ExpLoss", "Exp")] [assembly: LoadableClass(SquaredLoss.Summary, typeof(SquaredLoss), null, typeof(SignatureRegressionLoss), @@ -27,16 +27,16 @@ [assembly: LoadableClass(PoissonLoss.Summary, typeof(PoissonLoss), null, typeof(SignatureRegressionLoss), "Poisson Loss", "PoissonLoss", "Poisson")] -[assembly: LoadableClass(TweedieLoss.Summary, typeof(TweedieLoss), typeof(TweedieLoss.Arguments), typeof(SignatureRegressionLoss), +[assembly: LoadableClass(TweedieLoss.Summary, typeof(TweedieLoss), typeof(TweedieLoss.Options), typeof(SignatureRegressionLoss), "Tweedie Loss", "TweedieLoss", "Tweedie", "Tw")] -[assembly: EntryPointModule(typeof(ExpLoss.Arguments))] +[assembly: EntryPointModule(typeof(ExpLoss.Options))] [assembly: EntryPointModule(typeof(LogLossFactory))] -[assembly: EntryPointModule(typeof(HingeLoss.Arguments))] +[assembly: EntryPointModule(typeof(HingeLoss.Options))] [assembly: EntryPointModule(typeof(PoissonLossFactory))] -[assembly: EntryPointModule(typeof(SmoothedHingeLoss.Arguments))] +[assembly: EntryPointModule(typeof(SmoothedHingeLoss.Options))] [assembly: EntryPointModule(typeof(SquaredLossFactory))] -[assembly: EntryPointModule(typeof(TweedieLoss.Arguments))] +[assembly: EntryPointModule(typeof(TweedieLoss.Options))] namespace Microsoft.ML { @@ -162,7 +162,7 @@ private static Double Log(Double x) public sealed class HingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "HingeLoss", FriendlyName = "Hinge loss", Alias = "Hinge", Desc = "Hinge loss.")] - public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")] public Float Margin = Defaults.Margin; @@ -176,9 +176,9 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC private const Float Threshold = 0.5f; private readonly Float _margin; - internal HingeLoss(Arguments args) + internal HingeLoss(Options options) { - _margin = args.Margin; + _margin = options.Margin; } private static class Defaults @@ -187,7 +187,7 @@ private static class Defaults } public HingeLoss(float margin = Defaults.Margin) - : this(new Arguments() { Margin = margin }) + : this(new Options() { Margin = margin }) { } @@ -234,7 +234,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss { [TlcModule.Component(Name = "SmoothedHingeLoss", FriendlyName = "Smoothed Hinge Loss", Alias = "SmoothedHinge", Desc = "Smoothed Hinge loss.")] - public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory + public sealed class Options : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")] public Float SmoothingConst = Defaults.SmoothingConst; @@ -268,8 +268,8 @@ public SmoothedHingeLoss(float smoothingConstant = Defaults.SmoothingConst) _doubleSmoothConst = _smoothConst * 2; } - private SmoothedHingeLoss(IHostEnvironment env, Arguments args) - : this(args.SmoothingConst) + private SmoothedHingeLoss(IHostEnvironment env, Options options) + : this(options.SmoothingConst) { } @@ -333,7 +333,7 @@ public Double DualLoss(Float label, Double dual) public sealed class ExpLoss : IClassificationLoss { [TlcModule.Component(Name = "ExpLoss", FriendlyName = "Exponential Loss", Desc = "Exponential loss.")] - public sealed class Arguments : ISupportClassificationLossFactory + public sealed class Options : ISupportClassificationLossFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Beta (dilation)", ShortName = "beta")] public Float Beta = 1; @@ -345,9 +345,9 @@ public sealed class Arguments : ISupportClassificationLossFactory private readonly Float _beta; - public ExpLoss(Arguments args) + public ExpLoss(Options options) { - _beta = args.Beta; + _beta = options.Beta; } public Double Loss(Float output, Float label) @@ -438,7 +438,7 @@ public Float Derivative(Float output, Float label) public sealed class TweedieLoss : IRegressionLoss { [TlcModule.Component(Name = "TweedieLoss", FriendlyName = "Tweedie Loss", Alias = "tweedie", Desc = "Tweedie loss.")] - public sealed class Arguments : ISupportRegressionLossFactory + public sealed class Options : ISupportRegressionLossFactory { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " + @@ -454,10 +454,10 @@ public sealed class Arguments : ISupportRegressionLossFactory private readonly Double _index1; // 1 minus the index parameter. private readonly Double _index2; // 2 minus the index parameter. - public TweedieLoss(Arguments args) + public TweedieLoss(Options options) { - Contracts.CheckUserArg(1 <= args.Index && args.Index <= 2, nameof(args.Index), "Must be in the range [1, 2]"); - _index = args.Index; + Contracts.CheckUserArg(1 <= options.Index && options.Index <= 2, nameof(options.Index), "Must be in the range [1, 2]"); + _index = options.Index; _index1 = 1 - _index; _index2 = 2 - _index; } diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs index dd9eaf41b6..67bd565193 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs @@ -258,10 +258,10 @@ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipel switch (input.ModelCombiner) { case ClassifierCombiner.Median: - combiner = new MultiMedian(host, new MultiMedian.Arguments() { Normalize = true }); + combiner = new MultiMedian(host, new MultiMedian.Options() { Normalize = true }); break; case ClassifierCombiner.Average: - combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true }); + combiner = new MultiAverage(host, new MultiAverage.Options() { Normalize = true }); break; case ClassifierCombiner.Vote: combiner = new MultiVoting(host); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs index 044ccfb44e..053c8d03ba 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiAverager.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Trainers.Ensemble { public abstract class BaseMultiAverager : BaseMultiCombiner { - internal BaseMultiAverager(IHostEnvironment env, string name, ArgumentsBase args) + internal BaseMultiAverager(IHostEnvironment env, string name, OptionsBase args) : base(env, name, args) { } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs index d0ee4583b3..877f08d78a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs @@ -15,7 +15,7 @@ public abstract class BaseMultiCombiner : IMultiClassOutputCombiner, ICanSaveMod { protected readonly IHost Host; - public abstract class ArgumentsBase + public abstract class OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to normalize the output of base models before combining them", ShortName = "norm", SortOrder = 50)] @@ -24,7 +24,7 @@ public abstract class ArgumentsBase protected readonly bool Normalize; - internal BaseMultiCombiner(IHostEnvironment env, string name, ArgumentsBase args) + internal BaseMultiCombiner(IHostEnvironment env, string name, OptionsBase args) { Contracts.AssertValue(env); env.AssertNonWhiteSpace(name); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs index 3cf424b50f..0bfc938fdf 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs @@ -9,7 +9,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Trainers.Ensemble; -[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Arguments), typeof(SignatureCombiner), +[assembly: LoadableClass(typeof(MultiAverage), typeof(MultiAverage.Options), typeof(SignatureCombiner), Average.UserName, MultiAverage.LoadName)] [assembly: LoadableClass(typeof(MultiAverage), null, typeof(SignatureLoadModel), Average.UserName, MultiAverage.LoadName, MultiAverage.LoaderSignature)] @@ -33,13 +33,13 @@ private static VersionInfo GetVersionInfo() } [TlcModule.Component(Name = LoadName, FriendlyName = Average.UserName)] - public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory + public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory { public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiAverage(env, this); } - public MultiAverage(IHostEnvironment env, Arguments args) - : base(env, LoaderSignature, args) + public MultiAverage(IHostEnvironment env, Options options) + : base(env, LoaderSignature, options) { } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs index 3b44413397..f7a66c3dd7 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs @@ -10,7 +10,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Trainers.Ensemble; -[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Arguments), typeof(SignatureCombiner), +[assembly: LoadableClass(typeof(MultiMedian), typeof(MultiMedian.Options), typeof(SignatureCombiner), Median.UserName, MultiMedian.LoadName)] [assembly: LoadableClass(typeof(MultiMedian), null, typeof(SignatureLoadModel), Median.UserName, MultiMedian.LoaderSignature)] @@ -36,13 +36,13 @@ private static VersionInfo GetVersionInfo() } [TlcModule.Component(Name = LoadName, FriendlyName = Median.UserName)] - public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory + public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory { public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiMedian(env, this); } - public MultiMedian(IHostEnvironment env, Arguments args) - : base(env, LoaderSignature, args) + public MultiMedian(IHostEnvironment env, Options options) + : base(env, LoaderSignature, options) { } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs index f943a4176c..917071e42a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs @@ -33,7 +33,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(MultiVoting).Assembly.FullName); } - private sealed class Arguments : ArgumentsBase + private sealed class Arguments : OptionsBase { } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs index aecc3963bc..96f05371fa 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs @@ -11,7 +11,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Trainers.Ensemble; -[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Arguments), typeof(SignatureCombiner), +[assembly: LoadableClass(typeof(MultiWeightedAverage), typeof(MultiWeightedAverage.Options), typeof(SignatureCombiner), MultiWeightedAverage.UserName, MultiWeightedAverage.LoadName)] [assembly: LoadableClass(typeof(MultiWeightedAverage), null, typeof(SignatureLoadModel), @@ -40,7 +40,7 @@ private static VersionInfo GetVersionInfo() } [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] - public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory + public sealed class Options : OptionsBase, ISupportMulticlassOutputCombinerFactory { IMultiClassOutputCombiner IComponentFactory.CreateComponent(IHostEnvironment env) => new MultiWeightedAverage(env, this); @@ -52,11 +52,11 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF private readonly MultiWeightageKind _weightageKind; public string WeightageMetricName { get { return _weightageKind.ToString(); } } - public MultiWeightedAverage(IHostEnvironment env, Arguments args) - : base(env, LoaderSignature, args) + public MultiWeightedAverage(IHostEnvironment env, Options options) + : base(env, LoaderSignature, options) { - _weightageKind = args.WeightageName; - Host.CheckUserArg(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind), nameof(args.WeightageName)); + _weightageKind = options.WeightageName; + Host.CheckUserArg(Enum.IsDefined(typeof(MultiWeightageKind), _weightageKind), nameof(options.WeightageName)); } private MultiWeightedAverage(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs index e5aa458768..04e6f802c1 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs @@ -11,7 +11,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Trainers.Ensemble; -[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Arguments), typeof(SignatureCombiner), +[assembly: LoadableClass(typeof(WeightedAverage), typeof(WeightedAverage.Options), typeof(SignatureCombiner), WeightedAverage.UserName, WeightedAverage.LoadName)] [assembly: LoadableClass(typeof(WeightedAverage), null, typeof(SignatureLoadModel), @@ -37,7 +37,7 @@ private static VersionInfo GetVersionInfo() } [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] - public sealed class Arguments: ISupportBinaryOutputCombinerFactory + public sealed class Options: ISupportBinaryOutputCombinerFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)] [TGUI(Label = "Weightage Name", Description = "The weights are calculated according to the selected metric")] @@ -50,11 +50,11 @@ public sealed class Arguments: ISupportBinaryOutputCombinerFactory public string WeightageMetricName { get { return _weightageKind.ToString(); } } - public WeightedAverage(IHostEnvironment env, Arguments args) + public WeightedAverage(IHostEnvironment env, Options options) : base(env, LoaderSignature) { - _weightageKind = args.WeightageName; - Host.CheckUserArg(Enum.IsDefined(typeof(WeightageKind), _weightageKind), nameof(args.WeightageName)); + _weightageKind = options.WeightageName; + Host.CheckUserArg(Enum.IsDefined(typeof(WeightageKind), _weightageKind), nameof(options.WeightageName)); } private WeightedAverage(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 4420b15ba9..0ee01dfed4 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -44,7 +44,7 @@ public sealed class Arguments : ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Output combiner", ShortName = "oc", SortOrder = 5)] [TGUI(Label = "Output combiner", Description = "Output combiner type")] - public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments(); + public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Options(); // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer. [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] diff --git a/src/Microsoft.ML.EntryPoints/ImportTextData.cs b/src/Microsoft.ML.EntryPoints/ImportTextData.cs index 09b878448c..50544ce16d 100644 --- a/src/Microsoft.ML.EntryPoints/ImportTextData.cs +++ b/src/Microsoft.ML.EntryPoints/ImportTextData.cs @@ -49,7 +49,7 @@ public sealed class LoaderInput public IFileHandle InputFile; [Argument(ArgumentType.Required, ShortName = "args", HelpText = "Arguments", SortOrder = 2)] - public TextLoader.Arguments Arguments = new TextLoader.Arguments(); + public TextLoader.Options Arguments = new TextLoader.Options(); } [TlcModule.EntryPoint(Name = "Data.TextLoader", Desc = "Import a dataset from a text file")] diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 13facf2757..d154195982 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -29,32 +29,32 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, double learningRate) : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves) { - Args.LearningRates = learningRate; + FastTreeTrainerOptions.LearningRates = learningRate; } protected override void CheckArgs(IChannel ch) { - if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent) - Args.UseLineSearch = true; - if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent) - Args.UseLineSearch = true; + if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent) + FastTreeTrainerOptions.UseLineSearch = true; + if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent) + FastTreeTrainerOptions.UseLineSearch = true; - if (Args.CompressEnsemble && Args.WriteLastEnsemble) + if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble) throw ch.Except("Ensemble compression cannot be done when forcing to write last ensemble (hl)"); - if (Args.NumLeaves > 2 && Args.HistogramPoolSize > Args.NumLeaves - 1) + if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1) throw ch.Except("Histogram pool size (ps) must be at least 2."); - if (Args.NumLeaves > 2 && Args.HistogramPoolSize > Args.NumLeaves - 1) + if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1) throw ch.Except("Histogram pool size (ps) must be at most numLeaves - 1."); - if (Args.EnablePruning && !HasValidSet) + if (FastTreeTrainerOptions.EnablePruning && !HasValidSet) throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid)."); - if (Args.EarlyStoppingRule != null && !HasValidSet) + if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet) throw ch.Except("Cannot perform early stopping without a validation set (valid)."); - if (Args.UseTolerantPruning && (!Args.EnablePruning || !HasValidSet)) + if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet)) throw ch.Except("Cannot perform tolerant pruning (prtol) without pruning (pruning) and a validation set (valid)"); base.CheckArgs(ch); @@ -63,12 +63,12 @@ protected override void CheckArgs(IChannel ch) private protected override TreeLearner ConstructTreeLearner(IChannel ch) { return new LeastSquaresRegressionTreeLearner( - TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient, - Args.FeatureFirstUsePenalty, Args.FeatureReusePenalty, Args.SoftmaxTemperature, - Args.HistogramPoolSize, Args.RngSeed, Args.SplitFraction, Args.FilterZeroLambdas, - Args.AllowEmptyTrees, Args.GainConfidenceLevel, Args.MaxCategoricalGroupsPerNode, - Args.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining, - Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias); + TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient, + FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature, + FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction, FastTreeTrainerOptions.FilterZeroLambdas, + FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode, + FastTreeTrainerOptions.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining, + FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) @@ -77,7 +77,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm( OptimizationAlgorithm optimizationAlgorithm; IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch); - switch (Args.OptimizationAlgorithm) + switch (FastTreeTrainerOptions.OptimizationAlgorithm) { case BoostedTreeArgs.OptimizationAlgorithmType.GradientDescent: optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper); @@ -89,14 +89,14 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm( optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper); break; default: - throw ch.Except("Unknown optimization algorithm '{0}'", Args.OptimizationAlgorithm); + throw ch.Except("Unknown optimization algorithm '{0}'", FastTreeTrainerOptions.OptimizationAlgorithm); } optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch); optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch); - optimizationAlgorithm.Smoothing = Args.Smoothing; - optimizationAlgorithm.DropoutRate = Args.DropoutRate; - optimizationAlgorithm.DropoutRng = new Random(Args.RngSeed); + optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing; + optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate; + optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.RngSeed); optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph; return optimizationAlgorithm; @@ -104,7 +104,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm( protected override IGradientAdjuster MakeGradientWrapper(IChannel ch) { - if (!Args.BestStepRankingRegressionTrees) + if (!FastTreeTrainerOptions.BestStepRankingRegressionTrees) return base.MakeGradientWrapper(ch); // REVIEW: If this is ranking specific than cmd.bestStepRankingRegressionTrees and @@ -117,7 +117,7 @@ protected override IGradientAdjuster MakeGradientWrapper(IChannel ch) protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration) { - if (Args.EarlyStoppingRule == null) + if (FastTreeTrainerOptions.EarlyStoppingRule == null) return false; ch.AssertValue(ValidTest); @@ -133,7 +133,7 @@ protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earl // Create early stopping rule. if (earlyStoppingRule == null) { - earlyStoppingRule = Args.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter); + earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter); ch.Assert(earlyStoppingRule != null); } @@ -150,7 +150,7 @@ protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earl protected override int GetBestIteration(IChannel ch) { int bestIteration = Ensemble.NumTrees; - if (!Args.WriteLastEnsemble && PruningTest != null) + if (!FastTreeTrainerOptions.WriteLastEnsemble && PruningTest != null) { bestIteration = PruningTest.BestIteration; ch.Info("Pruning picked iteration {0}", bestIteration); @@ -163,15 +163,15 @@ protected override int GetBestIteration(IChannel ch) /// protected double BsrMaxTreeOutput() { - if (Args.BestStepRankingRegressionTrees) - return Args.MaxTreeOutput; + if (FastTreeTrainerOptions.BestStepRankingRegressionTrees) + return FastTreeTrainerOptions.MaxTreeOutput; else return -1; } protected override bool ShouldRandomStartOptimizer() { - return Args.RandomStart; + return FastTreeTrainerOptions.RandomStart; } } } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index a3cd3cc550..b35acb44eb 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -46,13 +46,13 @@ internal static class FastTreeShared public static readonly object TrainLock = new object(); } - public abstract class FastTreeTrainerBase : + public abstract class FastTreeTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer - where TArgs : TreeArgs, new() + where TOptions : TreeOptions, new() where TModel : class { - protected readonly TArgs Args; + protected readonly TOptions FastTreeTrainerOptions; protected readonly bool AllowGC; protected int FeatureCount; private protected InternalTreeEnsemble TrainedEnsemble; @@ -95,7 +95,7 @@ public abstract class FastTreeTrainerBase : // random for active features selection private Random _featureSelectionRandom; - protected string InnerArgs => CmdParser.GetSettings(Host, Args, new TArgs()); + protected string InnerArgs => CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions()); public override TrainerInfo Info { get; } @@ -116,22 +116,22 @@ private protected FastTreeTrainerBase(IHostEnvironment env, int minDatapointsInLeaves) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { - Args = new TArgs(); + FastTreeTrainerOptions = new TOptions(); // set up the directly provided values // override with the directly provided values. - Args.NumLeaves = numLeaves; - Args.NumTrees = numTrees; - Args.MinDocumentsInLeafs = minDatapointsInLeaves; + FastTreeTrainerOptions.NumLeaves = numLeaves; + FastTreeTrainerOptions.NumTrees = numTrees; + FastTreeTrainerOptions.MinDocumentsInLeafs = minDatapointsInLeaves; - Args.LabelColumn = label.Name; - Args.FeatureColumn = featureColumn; + FastTreeTrainerOptions.LabelColumn = label.Name; + FastTreeTrainerOptions.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = Optional.Explicit(weightColumn); + FastTreeTrainerOptions.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = Optional.Explicit(groupIdColumn); + FastTreeTrainerOptions.GroupIdColumn = Optional.Explicit(groupIdColumn); // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. @@ -149,11 +149,11 @@ private protected FastTreeTrainerBase(IHostEnvironment env, /// /// Constructor that is used when invoking the classes deriving from this, through maml. /// - private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + private protected FastTreeTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn)) { - Host.CheckValue(args, nameof(args)); - Args = args; + Host.CheckValue(options, nameof(options)); + FastTreeTrainerOptions = options; // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. @@ -186,7 +186,7 @@ protected virtual Float GetMaxLabel() private void Initialize(IHostEnvironment env) { - int numThreads = Args.NumThreads ?? Environment.ProcessorCount; + int numThreads = FastTreeTrainerOptions.NumThreads ?? Environment.ProcessorCount; if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) { using (var ch = Host.Start("FastTreeTrainerBase")) @@ -196,7 +196,7 @@ private void Initialize(IHostEnvironment env) + "setting of the environment. Using {0} training threads instead.", numThreads); } } - ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); + ParallelTraining = FastTreeTrainerOptions.ParallelTrainer != null ? FastTreeTrainerOptions.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); ParallelTraining.InitEnvironment(); Tests = new List(); @@ -207,15 +207,15 @@ private void Initialize(IHostEnvironment env) private protected void ConvertData(RoleMappedData trainData) { MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Value.Index, out CategoricalFeatures); - var useTranspose = UseTranspose(Args.DiskTranspose, trainData) && (ValidData == null || UseTranspose(Args.DiskTranspose, ValidData)); - var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocumentsInLeafs, GetMaxLabel()); + var useTranspose = UseTranspose(FastTreeTrainerOptions.DiskTranspose, trainData) && (ValidData == null || UseTranspose(FastTreeTrainerOptions.DiskTranspose, ValidData)); + var instanceConverter = new ExamplesToFastTreeBins(Host, FastTreeTrainerOptions.MaxBins, useTranspose, !FastTreeTrainerOptions.FeatureFlocks, FastTreeTrainerOptions.MinDocumentsInLeafs, GetMaxLabel()); - TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, CategoricalFeatures, Args.CategoricalSplit); + TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit); FeatureMap = instanceConverter.FeatureMap; if (ValidData != null) - ValidSet = instanceConverter.GetCompatibleDataset(ValidData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit); + ValidSet = instanceConverter.GetCompatibleDataset(ValidData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit); if (TestData != null) - TestSets = new[] { instanceConverter.GetCompatibleDataset(TestData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit) }; + TestSets = new[] { instanceConverter.GetCompatibleDataset(TestData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit) }; } private bool UseTranspose(bool? useTranspose, RoleMappedData data) @@ -246,7 +246,7 @@ protected void TrainCore(IChannel ch) } using (Timer.Time(TimerEvent.TotalTrain)) Train(ch); - if (Args.ExecutionTimes) + if (FastTreeTrainerOptions.ExecutionTimes) PrintExecutionTimes(ch); TrainedEnsemble = Ensemble; if (FeatureMap != null) @@ -274,24 +274,24 @@ protected virtual void PrintExecutionTimes(IChannel ch) protected virtual void CheckArgs(IChannel ch) { - Args.Check(ch); + FastTreeTrainerOptions.Check(ch); - IntArray.CompatibilityLevel = Args.FeatureCompressionLevel; + IntArray.CompatibilityLevel = FastTreeTrainerOptions.FeatureCompressionLevel; // change arguments - if (Args.HistogramPoolSize < 2) - Args.HistogramPoolSize = Args.NumLeaves * 2 / 3; - if (Args.HistogramPoolSize > Args.NumLeaves - 1) - Args.HistogramPoolSize = Args.NumLeaves - 1; + if (FastTreeTrainerOptions.HistogramPoolSize < 2) + FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumLeaves * 2 / 3; + if (FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1) + FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumLeaves - 1; - if (Args.BaggingSize > 0) + if (FastTreeTrainerOptions.BaggingSize > 0) { - int bagCount = Args.NumTrees / Args.BaggingSize; - if (bagCount * Args.BaggingSize != Args.NumTrees) + int bagCount = FastTreeTrainerOptions.NumTrees / FastTreeTrainerOptions.BaggingSize; + if (bagCount * FastTreeTrainerOptions.BaggingSize != FastTreeTrainerOptions.NumTrees) throw ch.Except("Number of trees should be a multiple of number bag size"); } - if (!(0 <= Args.GainConfidenceLevel && Args.GainConfidenceLevel < 1)) + if (!(0 <= FastTreeTrainerOptions.GainConfidenceLevel && FastTreeTrainerOptions.GainConfidenceLevel < 1)) throw ch.Except("Gain confidence level must be in the range [0,1)"); #if OLD_DATALOAD @@ -342,7 +342,7 @@ protected void PrintTestGraph(IChannel ch) // we call Tests computing no matter whether we require to print test graph ComputeTests(); - if (!Args.PrintTestGraph) + if (!FastTreeTrainerOptions.PrintTestGraph) return; if (Ensemble.NumTrees == 0) @@ -430,15 +430,15 @@ private float GetFeaturePercentInMemory(IChannel ch) protected bool[] GetActiveFeatures() { var activeFeatures = Utils.CreateArray(TrainSet.NumFeatures, true); - if (Args.FeatureFraction < 1.0) + if (FastTreeTrainerOptions.FeatureFraction < 1.0) { if (_featureSelectionRandom == null) - _featureSelectionRandom = new Random(Args.FeatureSelectSeed); + _featureSelectionRandom = new Random(FastTreeTrainerOptions.FeatureSelectSeed); for (int i = 0; i < TrainSet.NumFeatures; ++i) { if (activeFeatures[i]) - activeFeatures[i] = _featureSelectionRandom.NextDouble() <= Args.FeatureFraction; + activeFeatures[i] = _featureSelectionRandom.NextDouble() <= FastTreeTrainerOptions.FeatureFraction; } } @@ -602,8 +602,8 @@ private void GenerateActiveFeatureLists(int numberOfItems) protected virtual BaggingProvider CreateBaggingProvider() { - Contracts.Assert(Args.BaggingSize > 0); - return new BaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction); + Contracts.Assert(FastTreeTrainerOptions.BaggingSize > 0); + return new BaggingProvider(TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.BaggingTrainFraction); } protected virtual bool ShouldRandomStartOptimizer() @@ -614,7 +614,7 @@ protected virtual bool ShouldRandomStartOptimizer() protected virtual void Train(IChannel ch) { Contracts.AssertValue(ch); - int numTotalTrees = Args.NumTrees; + int numTotalTrees = FastTreeTrainerOptions.NumTrees; ch.Info( "Reserved memory for tree learner: {0} bytes", @@ -634,13 +634,13 @@ protected virtual void Train(IChannel ch) if (Ensemble.NumTrees < numTotalTrees && ShouldRandomStartOptimizer()) { ch.Info("Randomizing start point"); - OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, false); + OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.RngSeed, false); revertRandomStart = true; } ch.Info("Starting to train ..."); - BaggingProvider baggingProvider = Args.BaggingSize > 0 ? CreateBaggingProvider() : null; + BaggingProvider baggingProvider = FastTreeTrainerOptions.BaggingSize > 0 ? CreateBaggingProvider() : null; #if OLD_DATALOAD #if !NO_STORE @@ -676,7 +676,7 @@ protected virtual void Train(IChannel ch) bool[] activeFeatures = _activeFeatureSetQueue.Dequeue(); #endif - if (Args.BaggingSize > 0 && Ensemble.NumTrees % Args.BaggingSize == 0) + if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.NumTrees % FastTreeTrainerOptions.BaggingSize == 0) { baggingProvider.GenerateNewBag(); OptimizationAlgorithm.TreeLearner.Partitioning = @@ -699,7 +699,7 @@ protected virtual void Train(IChannel ch) emptyTrees++; numTotalTrees--; } - else if (Args.BaggingSize > 0 && Ensemble.Trees.Count() > 0) + else if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.Trees.Count() > 0) { ch.Assert(Ensemble.Trees.Last() == tree); Ensemble.Trees.Last() @@ -721,7 +721,7 @@ protected virtual void Train(IChannel ch) { revertRandomStart = false; ch.Info("Reverting random score assignment"); - OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, true); + OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.RngSeed, true); } #if !NO_STORE @@ -806,7 +806,7 @@ protected virtual void PrintIterationMessage(IChannel ch, IProgressChannel pch) protected virtual void PrintTestResults(IChannel ch) { - if (Args.TestFrequency != int.MaxValue && (Ensemble.NumTrees % Args.TestFrequency == 0 || Ensemble.NumTrees == Args.NumTrees)) + if (FastTreeTrainerOptions.TestFrequency != int.MaxValue && (Ensemble.NumTrees % FastTreeTrainerOptions.TestFrequency == 0 || Ensemble.NumTrees == FastTreeTrainerOptions.NumTrees)) { var sb = new StringBuilder(); using (var sw = new StringWriter(sb)) @@ -826,9 +826,9 @@ protected virtual void PrintPrologInfo(IChannel ch) { Contracts.AssertValue(ch); ch.Trace("Host = {0}", Environment.MachineName); - ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, Args, new TArgs())); + ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions())); ch.Trace("GCSettings.IsServerGC = {0}", System.Runtime.GCSettings.IsServerGC); - ch.Trace("{0}", Args); + ch.Trace("{0}", FastTreeTrainerOptions); } protected ScoreTracker ConstructScoreTracker(Dataset set) @@ -856,7 +856,7 @@ protected ScoreTracker ConstructScoreTracker(Dataset set) private double[] ComputeScoresSmart(IChannel ch, Dataset set) { - if (!Args.CompressEnsemble) + if (!FastTreeTrainerOptions.CompressEnsemble) { foreach (var st in OptimizationAlgorithm.TrackedScores) if (st.Dataset == set) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 97eae76ac6..f3afaa18e0 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -150,7 +150,7 @@ internal static class Defaults public const double LearningRates = 0.2; } - public abstract class TreeArgs : LearnerInputBaseWithGroupId + public abstract class TreeOptions : LearnerInputBaseWithGroupId { /// /// Allows to choose Parallel FastTree Learning Algorithm. @@ -442,7 +442,7 @@ internal virtual void Check(IExceptionContext ectx) } } - public abstract class BoostedTreeArgs : TreeArgs + public abstract class BoostedTreeArgs : TreeOptions { // REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate. //Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero. diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index a2428aec27..e59afb3358 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -139,7 +139,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss - _sigmoidParameter = 2.0 * Args.LearningRates; + _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRates; } /// @@ -151,7 +151,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options optio : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) { // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss - _sigmoidParameter = 2.0 * Args.LearningRates; + _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRates; } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -193,25 +193,25 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) return new ObjectiveImpl( TrainSet, _trainSetLabels, - Args.LearningRates, - Args.Shrinkage, + FastTreeTrainerOptions.LearningRates, + FastTreeTrainerOptions.Shrinkage, _sigmoidParameter, - Args.UnbalancedSets, - Args.MaxTreeOutput, - Args.GetDerivativesSampleRate, - Args.BestStepRankingRegressionTrees, - Args.RngSeed, + FastTreeTrainerOptions.UnbalancedSets, + FastTreeTrainerOptions.MaxTreeOutput, + FastTreeTrainerOptions.GetDerivativesSampleRate, + FastTreeTrainerOptions.BestStepRankingRegressionTrees, + FastTreeTrainerOptions.RngSeed, ParallelTraining); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) { OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - if (Args.UseLineSearch) + if (FastTreeTrainerOptions.UseLineSearch) { var lossCalculator = new BinaryClassificationTest(optimizationAlgorithm.TrainingScores, _trainSetLabels, _sigmoidParameter); // REVIEW: we should makeloss indices an enum in BinaryClassificationTest - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, Args.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/, Args.NumPostBracketSteps, Args.MinStepSize); + optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, FastTreeTrainerOptions.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize); } return optimizationAlgorithm; } @@ -258,9 +258,9 @@ protected override void InitializeTests() } } - if (Args.EnablePruning && ValidSet != null) + if (FastTreeTrainerOptions.EnablePruning && ValidSet != null) { - if (!Args.UseTolerantPruning) + if (!FastTreeTrainerOptions.UseTolerantPruning) { //use simple early stopping condition PruningTest = new TestHistory(ValidTest, 0); @@ -268,7 +268,7 @@ protected override void InitializeTests() else { //use tollerant stopping condition - PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold); + PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); } } } diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index b3d0aa12af..fa4e137376 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -132,8 +132,8 @@ private Double[] GetLabelGains() { try { - Host.AssertValue(Args.CustomGains); - return Args.CustomGains.Split(',').Select(k => Convert.ToDouble(k.Trim())).ToArray(); + Host.AssertValue(FastTreeTrainerOptions.CustomGains); + return FastTreeTrainerOptions.CustomGains.Split(',').Select(k => Convert.ToDouble(k.Trim())).ToArray(); } catch (Exception ex) { @@ -145,12 +145,12 @@ private Double[] GetLabelGains() protected override void CheckArgs(IChannel ch) { - if (!string.IsNullOrEmpty(Args.CustomGains)) + if (!string.IsNullOrEmpty(FastTreeTrainerOptions.CustomGains)) { - var stringGain = Args.CustomGains.Split(','); + var stringGain = FastTreeTrainerOptions.CustomGains.Split(','); if (stringGain.Length < 5) { - throw ch.ExceptUserArg(nameof(Args.CustomGains), + throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains), "{0} an invalid number of gain levels. We require at least 5. Make certain they're comma separated.", stringGain.Length); } @@ -159,7 +159,7 @@ protected override void CheckArgs(IChannel ch) { if (!Double.TryParse(stringGain[i], out gain[i])) { - throw ch.ExceptUserArg(nameof(Args.CustomGains), + throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains), "Could not parse '{0}' as a floating point number", stringGain[0]); } } @@ -167,7 +167,7 @@ protected override void CheckArgs(IChannel ch) Dataset.DatasetSkeleton.LabelGainMap = gain; } - ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics == 1 || Args.EarlyStoppingMetrics == 3), nameof(Args.EarlyStoppingMetrics), + ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 3."); base.CheckArgs(ch); @@ -176,33 +176,33 @@ protected override void CheckArgs(IChannel ch) protected override void Initialize(IChannel ch) { base.Initialize(ch); - if (Args.CompressEnsemble) + if (FastTreeTrainerOptions.CompressEnsemble) { _ensembleCompressor = new LassoBasedEnsembleCompressor(); - _ensembleCompressor.Initialize(Args.NumTrees, TrainSet, TrainSet.Ratings, Args.RngSeed); + _ensembleCompressor.Initialize(FastTreeTrainerOptions.NumTrees, TrainSet, TrainSet.Ratings, FastTreeTrainerOptions.RngSeed); } } protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { - return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, Args, ParallelTraining); + return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, FastTreeTrainerOptions, ParallelTraining); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) { OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - if (Args.UseLineSearch) + if (FastTreeTrainerOptions.UseLineSearch) { - _specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, Args.SortingAlgorithm, Args.EarlyStoppingMetrics); - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, Args.NumPostBracketSteps, Args.MinStepSize); + _specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm, FastTreeTrainerOptions.EarlyStoppingMetrics); + optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize); } return optimizationAlgorithm; } protected override BaggingProvider CreateBaggingProvider() { - Host.Assert(Args.BaggingSize > 0); - return new RankingBaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction); + Host.Assert(FastTreeTrainerOptions.BaggingSize > 0); + return new RankingBaggingProvider(TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.BaggingTrainFraction); } protected override void PrepareLabels(IChannel ch) @@ -211,17 +211,17 @@ protected override void PrepareLabels(IChannel ch) protected override Test ConstructTestForTrainingData() { - return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, Args.SortingAlgorithm); + return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm); } protected override void InitializeTests() { - if (Args.TestFrequency != int.MaxValue) + if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) { AddFullTests(); } - if (Args.PrintTestGraph) + if (FastTreeTrainerOptions.PrintTestGraph) { // If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity) // We need initialize first set for graph printing @@ -238,14 +238,14 @@ protected override void InitializeTests() if (ValidSet != null) ValidTest = CreateSpecialValidSetTest(); - if (Args.PrintTrainValidGraph && Args.EnablePruning && _specialTrainSetTest == null) + if (FastTreeTrainerOptions.PrintTrainValidGraph && FastTreeTrainerOptions.EnablePruning && _specialTrainSetTest == null) { _specialTrainSetTest = CreateSpecialTrainSetTest(); } - if (Args.EnablePruning && ValidTest != null) + if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) { - if (!Args.UseTolerantPruning) + if (!FastTreeTrainerOptions.UseTolerantPruning) { //use simple eraly stopping condition PruningTest = new TestHistory(ValidTest, 0); @@ -253,7 +253,7 @@ protected override void InitializeTests() else { //use tolerant stopping condition - PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold); + PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); } } } @@ -375,7 +375,7 @@ protected override void Train(IChannel ch) private protected override void CustomizedTrainingIteration(InternalRegressionTree tree) { Contracts.AssertValueOrNull(tree); - if (tree != null && Args.CompressEnsemble) + if (tree != null && FastTreeTrainerOptions.CompressEnsemble) { double[] trainOutputs = Ensemble.GetTreeAt(Ensemble.NumTrees - 1).GetOutputs(TrainSet); _ensembleCompressor.SetTreeScores(Ensemble.NumTrees - 1, trainOutputs); @@ -395,7 +395,7 @@ private Test CreateStandardTest(Dataset dataset) return new NdcgTest( ConstructScoreTracker(dataset), dataset.Ratings, - Args.SortingAlgorithm); + FastTreeTrainerOptions.SortingAlgorithm); } /// @@ -408,8 +408,8 @@ private Test CreateSpecialTrainSetTest() OptimizationAlgorithm.TrainingScores, OptimizationAlgorithm.ObjectiveFunction as LambdaRankObjectiveFunction, TrainSet.Ratings, - Args.SortingAlgorithm, - Args.EarlyStoppingMetrics); + FastTreeTrainerOptions.SortingAlgorithm, + FastTreeTrainerOptions.EarlyStoppingMetrics); } /// @@ -421,8 +421,8 @@ private Test CreateSpecialValidSetTest() return new FastNdcgTest( ConstructScoreTracker(ValidSet), ValidSet.Ratings, - Args.SortingAlgorithm, - Args.EarlyStoppingMetrics); + FastTreeTrainerOptions.SortingAlgorithm, + FastTreeTrainerOptions.EarlyStoppingMetrics); } /// @@ -442,12 +442,12 @@ protected override string GetTestGraphHeader() { StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10"); - if (Args.PrintTrainValidGraph) + if (FastTreeTrainerOptions.PrintTrainValidGraph) { headerBuilder.Append("\tNDCG@20\tNDCG@40"); headerBuilder.AppendFormat( "\nNote: Printing train NDCG@{0} as NDCG@20 and validation NDCG@{0} as NDCG@40..\n", - Args.EarlyStoppingMetrics); + FastTreeTrainerOptions.EarlyStoppingMetrics); } return headerBuilder.ToString(); diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 52a9185b9d..e6bdaa7e54 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -107,7 +107,7 @@ protected override void CheckArgs(IChannel ch) base.CheckArgs(ch); - ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2), nameof(Args.EarlyStoppingMetrics), + ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); } @@ -118,17 +118,17 @@ private static SchemaShape.Column MakeLabelColumn(string labelColumn) protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { - return new ObjectiveImpl(TrainSet, Args); + return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) { OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - if (Args.UseLineSearch) + if (FastTreeTrainerOptions.UseLineSearch) { var lossCalculator = new RegressionTest(optimizationAlgorithm.TrainingScores); // REVIEW: We should make loss indices an enum in BinaryClassificationTest. - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, Args.NumPostBracketSteps, Args.MinStepSize); + optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize); } return optimizationAlgorithm; @@ -229,10 +229,10 @@ protected virtual void AddFullNDCGTests() protected override void InitializeTests() { // Initialize regression tests. - if (Args.TestFrequency != int.MaxValue) + if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) AddFullRegressionTests(); - if (Args.PrintTestGraph) + if (FastTreeTrainerOptions.PrintTestGraph) { // If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity), // we need initialize first set for graph printing. @@ -244,24 +244,24 @@ protected override void InitializeTests() } } - if (Args.PrintTrainValidGraph && _trainRegressionTest == null) + if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null) { Test trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet)); _trainRegressionTest = trainRegressionTest; } - if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) + if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) _testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0])); // Add early stopping if appropriate. - TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics); + TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics); if (ValidSet != null) - ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics); + ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics); - if (Args.EnablePruning && ValidTest != null) + if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) { - if (Args.UseTolerantPruning) // Use simple early stopping condition. - PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold); + if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition. + PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); else PruningTest = new TestHistory(ValidTest, 0); } @@ -311,7 +311,7 @@ protected override string GetTestGraphHeader() { StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10"); - if (Args.PrintTrainValidGraph) + if (FastTreeTrainerOptions.PrintTrainValidGraph) { headerBuilder.Append("\tNDCG@20\tNDCG@40"); headerBuilder.Append("\nNote: Printing train L2 error as NDCG@20 and test L2 error as NDCG@40..\n"); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 03df71233d..2e2c27deb7 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -115,27 +115,27 @@ protected override void CheckArgs(IChannel ch) // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now // we just leave the existing regression checks, though with a warning. - if (Args.EarlyStoppingMetrics > 0) + if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0) ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution."); - ch.CheckUserArg((Args.EarlyStoppingRule == null && !Args.EnablePruning) || (Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2), nameof(Args.EarlyStoppingMetrics), + ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); } protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { - return new ObjectiveImpl(TrainSet, Args); + return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions); } private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch) { OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch); - if (Args.UseLineSearch) + if (FastTreeTrainerOptions.UseLineSearch) { var lossCalculator = new RegressionTest(optimizationAlgorithm.TrainingScores); // REVIEW: We should make loss indices an enum in BinaryClassificationTest. // REVIEW: Nope, subcomponent. - optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, Args.NumPostBracketSteps, Args.MinStepSize); + optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.NumPostBracketSteps, FastTreeTrainerOptions.MinStepSize); } return optimizationAlgorithm; @@ -171,7 +171,7 @@ protected override Test ConstructTestForTrainingData() private void Initialize() { - Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]"); + Host.CheckUserArg(1 <= FastTreeTrainerOptions.Index && FastTreeTrainerOptions.Index <= 2, nameof(FastTreeTrainerOptions.Index), "Must be in the range [1, 2]"); _outputColumns = new[] { @@ -202,10 +202,10 @@ private void AddFullRegressionTests() protected override void InitializeTests() { // Initialize regression tests. - if (Args.TestFrequency != int.MaxValue) + if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) AddFullRegressionTests(); - if (Args.PrintTestGraph) + if (FastTreeTrainerOptions.PrintTestGraph) { // If FirstTestHistory is null (which means the tests were not intialized due to /tf==infinity), // we need initialize first set for graph printing. @@ -217,24 +217,24 @@ protected override void InitializeTests() } } - if (Args.PrintTrainValidGraph && _trainRegressionTest == null) + if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null) { Test trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet)); _trainRegressionTest = trainRegressionTest; } - if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) + if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) _testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0])); // Add early stopping if appropriate. - TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics); + TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics); if (ValidSet != null) - ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics); + ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics); - if (Args.EnablePruning && ValidTest != null) + if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) { - if (Args.UseTolerantPruning) // Use simple early stopping condition. - PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold); + if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition. + PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); else PruningTest = new TestHistory(ValidTest, 0); } @@ -249,7 +249,7 @@ protected override string GetTestGraphHeader() { StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10"); - if (Args.PrintTrainValidGraph) + if (FastTreeTrainerOptions.PrintTrainValidGraph) { headerBuilder.Append("\tNDCG@20\tNDCG@40"); headerBuilder.Append("\nNote: Printing train L2 error as NDCG@20 and test L2 error as NDCG@40..\n"); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 8dd51c03b5..a0d199fc4b 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -33,7 +33,7 @@ public sealed class BinaryClassificationGamTrainer : BinaryPredictionTransformer>, CalibratedModelParametersBase> { - public sealed class Options : ArgumentsBase + public sealed class Options : OptionsBase { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")] [TGUI(Label = "Optimize for unbalanced")] diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index d740e13589..56ec5f6176 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Trainers.FastTree { public sealed class RegressionGamTrainer : GamTrainerBase, RegressionGamModelParameters> { - public partial class Options : ArgumentsBase + public partial class Options : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Metric for pruning. (For regression, 1: L1, 2:L2; default L2)", ShortName = "pmetric")] [TGUI(Description = "Metric for pruning. (For regression, 1: L1, 2:L2; default L2")] diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index d1017b4f1d..9af1ea1c6b 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -51,10 +51,10 @@ namespace Microsoft.ML.Trainers.FastTree /// public abstract partial class GamTrainerBase : TrainerEstimatorBase where TTransformer: ISingleFeaturePredictionTransformer - where TArgs : GamTrainerBase.ArgumentsBase, new() + where TArgs : GamTrainerBase.OptionsBase, new() where TPredictor : class { - public abstract class ArgumentsBase : LearnerInputBaseWithWeight + public abstract class OptionsBase : LearnerInputBaseWithWeight { [Argument(ArgumentType.LastOccurenceWins, HelpText = "The entropy (regularization) coefficient between 0 and 1", ShortName = "e")] public double EntropyCoefficient; diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 419f2b0129..425abb73e9 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -5,7 +5,7 @@ namespace Microsoft.ML.Trainers.FastTree { public abstract class RandomForestTrainerBase : FastTreeTrainerBase - where TArgs : FastForestArgumentsBase, new() + where TArgs : FastForestOptionsBase, new() where TModel : class where TTransformer: ISingleFeaturePredictionTransformer { @@ -46,7 +46,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm( optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch); optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch); - optimizationAlgorithm.Smoothing = Args.Smoothing; + optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing; // No notion of dropout for non-boosting applications. optimizationAlgorithm.DropoutRate = 0; optimizationAlgorithm.DropoutRng = null; @@ -62,12 +62,12 @@ protected override void InitializeTests() private protected override TreeLearner ConstructTreeLearner(IChannel ch) { return new RandomForestLeastSquaresTreeLearner( - TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient, - Args.FeatureFirstUsePenalty, Args.FeatureReusePenalty, Args.SoftmaxTemperature, - Args.HistogramPoolSize, Args.RngSeed, Args.SplitFraction, - Args.AllowEmptyTrees, Args.GainConfidenceLevel, Args.MaxCategoricalGroupsPerNode, - Args.MaxCategoricalSplitPoints, _quantileEnabled, Args.QuantileSampleCount, ParallelTraining, - Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias); + TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient, + FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature, + FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction, + FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode, + FastTreeTrainerOptions.MaxCategoricalSplitPoints, _quantileEnabled, FastTreeTrainerOptions.QuantileSampleCount, ParallelTraining, + FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias); } public abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 348eed8429..232b7d62ff 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -32,12 +32,12 @@ namespace Microsoft.ML.Trainers.FastTree { - public abstract class FastForestArgumentsBase : TreeArgs + public abstract class FastForestOptionsBase : TreeOptions { [Argument(ArgumentType.AtMostOnce, HelpText = "Number of labels to be sampled from each leaf to make the distribtuion", ShortName = "qsc")] public int QuantileSampleCount = 100; - public FastForestArgumentsBase() + public FastForestOptionsBase() { FeatureFraction = 0.7; BaggingSize = 1; @@ -111,7 +111,7 @@ private static IPredictorProducing Create(IHostEnvironment env, ModelLoad public sealed partial class FastForestClassification : RandomForestTrainerBase, FastForestClassificationModelParameters> { - public sealed class Options : FastForestArgumentsBase + public sealed class Options : FastForestOptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")] public Double MaxTreeOutput = 100; @@ -196,7 +196,7 @@ private protected override FastForestClassificationModelParameters TrainModelCor protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { - return new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, Args); + return new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, FastTreeTrainerOptions); } protected override void PrepareLabels(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index f1a348706c..ddb79ed585 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -138,7 +138,7 @@ ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantil public sealed partial class FastForestRegression : RandomForestTrainerBase, FastForestRegressionModelParameters> { - public sealed class Options : FastForestArgumentsBase + public sealed class Options : FastForestOptionsBase { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Shuffle the labels on every iteration. " + "Useful probably only if using this tree as a tree leaf featurizer for multiclass.")] @@ -204,7 +204,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr ConvertData(trainData); TrainCore(ch); } - return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); + return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, FastTreeTrainerOptions.QuantileSampleCount); } protected override void PrepareLabels(IChannel ch) @@ -213,7 +213,7 @@ protected override void PrepareLabels(IChannel ch) protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { - return ObjectiveFunctionImplBase.Create(TrainSet, Args); + return ObjectiveFunctionImplBase.Create(TrainSet, FastTreeTrainerOptions); } protected override Test ConstructTestForTrainingData() diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs index 32768bc9a9..35b318e41a 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs @@ -106,14 +106,14 @@ internal ImageGrayscalingTransformer(IHostEnvironment env, params (string output } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); + env.CheckValue(options.Columns, nameof(options.Columns)); - return new ImageGrayscalingTransformer(env, args.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray()) + return new ImageGrayscalingTransformer(env, options.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray()) .MakeDataTransform(input); } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs index 1c972e18f0..6906af7c61 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs @@ -97,9 +97,9 @@ internal ImageLoadingTransformer(IHostEnvironment env, string imageFolder = null } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView data) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView data) { - return new ImageLoadingTransformer(env, args.ImageFolder, args.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray()) + return new ImageLoadingTransformer(env, options.ImageFolder, options.Columns.Select(x => (x.Name, x.Source ?? x.Name)).ToArray()) .MakeDataTransform(data); } diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs index 2b7d2d0243..158e888310 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs @@ -199,19 +199,19 @@ private static (string outputColumnName, string inputColumnName)[] GetColumnPair } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); + env.CheckValue(options.Columns, nameof(options.Columns)); - var columns = new ImagePixelExtractingEstimator.ColumnInfo[args.Columns.Length]; + var columns = new ImagePixelExtractingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < columns.Length; i++) { - var item = args.Columns[i]; - columns[i] = new ImagePixelExtractingEstimator.ColumnInfo(item, args); + var item = options.Columns[i]; + columns[i] = new ImagePixelExtractingEstimator.ColumnInfo(item, options); } var transformer = new ImagePixelExtractingTransformer(env, columns); @@ -549,23 +549,23 @@ public sealed class ColumnInfo internal bool Green => (Colors & ColorBits.Green) != 0; internal bool Blue => (Colors & ColorBits.Blue) != 0; - internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtractingTransformer.Options args) + internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtractingTransformer.Options options) { Contracts.CheckValue(item, nameof(item)); - Contracts.CheckValue(args, nameof(args)); + Contracts.CheckValue(options, nameof(options)); Name = item.Name; InputColumnName = item.Source ?? item.Name; - if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; } - if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; } - if (item.UseGreen ?? args.UseGreen) { Colors |= ColorBits.Green; Planes++; } - if (item.UseBlue ?? args.UseBlue) { Colors |= ColorBits.Blue; Planes++; } + if (item.UseAlpha ?? options.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; } + if (item.UseRed ?? options.UseRed) { Colors |= ColorBits.Red; Planes++; } + if (item.UseGreen ?? options.UseGreen) { Colors |= ColorBits.Green; Planes++; } + if (item.UseBlue ?? options.UseBlue) { Colors |= ColorBits.Blue; Planes++; } Contracts.CheckUserArg(Planes > 0, nameof(item.UseRed), "Need to use at least one color plane"); - Interleave = item.InterleaveArgb ?? args.InterleaveArgb; + Interleave = item.InterleaveArgb ?? options.InterleaveArgb; - AsFloat = item.Convert ?? args.Convert; + AsFloat = item.Convert ?? options.Convert; if (!AsFloat) { Offset = ImagePixelExtractingTransformer.Defaults.Offset; @@ -573,8 +573,8 @@ internal ColumnInfo(ImagePixelExtractingTransformer.Column item, ImagePixelExtra } else { - Offset = item.Offset ?? args.Offset ?? ImagePixelExtractingTransformer.Defaults.Offset; - Scale = item.Scale ?? args.Scale ?? ImagePixelExtractingTransformer.Defaults.Scale; + Offset = item.Offset ?? options.Offset ?? ImagePixelExtractingTransformer.Defaults.Offset; + Scale = item.Scale ?? options.Scale ?? ImagePixelExtractingTransformer.Defaults.Scale; Contracts.CheckUserArg(FloatUtils.IsFinite(Offset), nameof(item.Offset)); Contracts.CheckUserArg(FloatUtils.IsFiniteNonZero(Scale), nameof(item.Scale)); } diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 82fedef36b..e66e3bb814 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -51,7 +51,7 @@ internal static class Defaults public const int ClustersCount = 5; } - public class Options : UnsupervisedLearnerInputBaseWithWeight + public sealed class Options : UnsupervisedLearnerInputBaseWithWeight { /// /// The number of clusters. diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs index 0085ad31aa..8cf005f339 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs @@ -171,8 +171,8 @@ public class Options : ISupportBoosterParameterFactory IBoosterParameter IComponentFactory.CreateComponent(IHostEnvironment env) => CreateComponent(env); } - internal TreeBooster(Options args) - : base(args) + internal TreeBooster(Options options) + : base(options) { Contracts.CheckUserArg(Args.MinSplitGain >= 0, nameof(Args.MinSplitGain), "must be >= 0."); Contracts.CheckUserArg(Args.MinChildWeight >= 0, nameof(Args.MinChildWeight), "must be >= 0."); @@ -194,7 +194,7 @@ public sealed class DartBooster : BoosterParameter internal const string FriendlyName = "Tree Dropout Tree Booster"; [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866")] - public class Options : TreeBooster.Options + public sealed class Options : TreeBooster.Options { [Argument(ArgumentType.AtMostOnce, HelpText = "Drop ratio for trees. Range:(0,1).")] [TlcModule.Range(Inf = 0.0, Max = 1.0)] @@ -217,8 +217,8 @@ public class Options : TreeBooster.Options internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new DartBooster(this); } - internal DartBooster(Options args) - : base(args) + internal DartBooster(Options options) + : base(options) { Contracts.CheckUserArg(Args.DropRate > 0 && Args.DropRate < 1, nameof(Args.DropRate), "must be in (0,1)."); Contracts.CheckUserArg(Args.MaxDrop > 0, nameof(Args.MaxDrop), "must be > 0."); @@ -238,7 +238,7 @@ public sealed class GossBooster : BoosterParameter internal const string FriendlyName = "Gradient-based One-Size Sampling"; [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")] - public class Options : TreeBooster.Options + public sealed class Options : TreeBooster.Options { [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for large gradient instances.")] @@ -254,8 +254,8 @@ public class Options : TreeBooster.Options internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new GossBooster(this); } - internal GossBooster(Options args) - : base(args) + internal GossBooster(Options options) + : base(options) { Contracts.CheckUserArg(Args.TopRate > 0 && Args.TopRate < 1, nameof(Args.TopRate), "must be in (0,1)."); Contracts.CheckUserArg(Args.OtherRate >= 0 && Args.OtherRate < 1, nameof(Args.TopRate), "must be in [0,1)."); diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 1c25e8ef08..6fb7cb70e3 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -164,16 +164,16 @@ protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCa { base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true); int numLeaves = (int)Options["num_leaves"]; - int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass); + int minDataPerLeaf = LightGbmTrainerOptions.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass); Options["min_data_per_leaf"] = minDataPerLeaf; if (!hiddenMsg) { - if (!Args.LearningRate.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.LearningRate) + " = " + Options["learning_rate"]); - if (!Args.NumLeaves.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.NumLeaves) + " = " + numLeaves); - if (!Args.MinDataPerLeaf.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.MinDataPerLeaf) + " = " + minDataPerLeaf); + if (!LightGbmTrainerOptions.LearningRate.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + Options["learning_rate"]); + if (!LightGbmTrainerOptions.NumLeaves.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumLeaves) + " = " + numLeaves); + if (!LightGbmTrainerOptions.MinDataPerLeaf.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinDataPerLeaf) + " = " + minDataPerLeaf); } } @@ -185,14 +185,14 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel Options["num_class"] = _numClass; bool useSoftmax = false; - if (Args.UseSoftmax.HasValue) - useSoftmax = Args.UseSoftmax.Value; + if (LightGbmTrainerOptions.UseSoftmax.HasValue) + useSoftmax = LightGbmTrainerOptions.UseSoftmax.Value; else { if (labels.Length >= _minDataToUseSoftmax) useSoftmax = true; - ch.Info("Auto-tuning parameters: " + nameof(Args.UseSoftmax) + " = " + useSoftmax); + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseSoftmax) + " = " + useSoftmax); } if (useSoftmax) diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index febe35eb9d..67f8d42676 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -40,7 +40,7 @@ private sealed class CategoricalMetaData public bool[] IsCategoricalFeature; } - private protected readonly Options Args; + private protected readonly Options LightGbmTrainerOptions; /// /// Stores argumments as objects to convert them to invariant string type in the end so that @@ -69,21 +69,21 @@ private protected LightGbmTrainerBase(IHostEnvironment env, int numBoostRound) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { - Args = new Options(); + LightGbmTrainerOptions = new Options(); - Args.NumLeaves = numLeaves; - Args.MinDataPerLeaf = minDataPerLeaf; - Args.LearningRate = learningRate; - Args.NumBoostRound = numBoostRound; + LightGbmTrainerOptions.NumLeaves = numLeaves; + LightGbmTrainerOptions.MinDataPerLeaf = minDataPerLeaf; + LightGbmTrainerOptions.LearningRate = learningRate; + LightGbmTrainerOptions.NumBoostRound = numBoostRound; - Args.LabelColumn = label.Name; - Args.FeatureColumn = featureColumn; + LightGbmTrainerOptions.LabelColumn = label.Name; + LightGbmTrainerOptions.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = Optional.Explicit(weightColumn); + LightGbmTrainerOptions.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = Optional.Explicit(groupIdColumn); + LightGbmTrainerOptions.GroupIdColumn = Optional.Explicit(groupIdColumn); InitParallelTraining(); } @@ -93,7 +93,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options { Host.CheckValue(options, nameof(options)); - Args = options; + LightGbmTrainerOptions = options; InitParallelTraining(); } @@ -132,8 +132,8 @@ private protected override TModel TrainModelCore(TrainContext context) private void InitParallelTraining() { - Options = Args.ToDictionary(Host); - ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); + Options = LightGbmTrainerOptions.ToDictionary(Host); + ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1) { @@ -170,20 +170,20 @@ private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false) { - double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats); - int numLeaves = Args.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats); - int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1); + double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats); + int numLeaves = LightGbmTrainerOptions.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats); + int minDataPerLeaf = LightGbmTrainerOptions.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1); Options["learning_rate"] = learningRate; Options["num_leaves"] = numLeaves; Options["min_data_per_leaf"] = minDataPerLeaf; if (!hiddenMsg) { - if (!Args.LearningRate.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.LearningRate) + " = " + learningRate); - if (!Args.NumLeaves.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.NumLeaves) + " = " + numLeaves); - if (!Args.MinDataPerLeaf.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.MinDataPerLeaf) + " = " + minDataPerLeaf); + if (!LightGbmTrainerOptions.LearningRate.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + learningRate); + if (!LightGbmTrainerOptions.NumLeaves.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumLeaves) + " = " + numLeaves); + if (!LightGbmTrainerOptions.MinDataPerLeaf.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinDataPerLeaf) + " = " + minDataPerLeaf); } } @@ -278,9 +278,9 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t int[] categoricalFeatures = null; const int useCatThreshold = 50000; // Disable cat when data is too small, reduce the overfitting. - bool useCat = Args.UseCat ?? numRow > useCatThreshold; - if (!Args.UseCat.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(Args.UseCat) + " = " + useCat); + bool useCat = LightGbmTrainerOptions.UseCat ?? numRow > useCatThreshold; + if (!LightGbmTrainerOptions.UseCat.HasValue) + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseCat) + " = " + useCat); if (useCat) { var featureCol = trainData.Schema.Schema[DefaultColumnNames.Features]; @@ -329,7 +329,7 @@ private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out Cate } // Push rows into dataset. - LoadDataset(ch, factory, dtrain, numRow, Args.BatchSize, catMetaData); + LoadDataset(ch, factory, dtrain, numRow, LightGbmTrainerOptions.BatchSize, catMetaData); // Some checks. CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups); @@ -353,7 +353,7 @@ private Dataset LoadValidationData(IChannel ch, Dataset dtrain, RoleMappedData v Dataset dvalid = new Dataset(dtrain, numRow, labels, weights, groups); // Push rows into dataset. - LoadDataset(ch, factory, dvalid, numRow, Args.BatchSize, catMetaData); + LoadDataset(ch, factory, dvalid, numRow, LightGbmTrainerOptions.BatchSize, catMetaData); return dvalid; } @@ -373,8 +373,8 @@ private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, Catego { ch.Info("LightGBM objective={0}", Options["objective"]); using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain, - dvalid: dvalid, numIteration: Args.NumBoostRound, - verboseEval: Args.VerboseEval, earlyStoppingRound: Args.EarlyStoppingRound)) + dvalid: dvalid, numIteration: LightGbmTrainerOptions.NumBoostRound, + verboseEval: LightGbmTrainerOptions.VerboseEval, earlyStoppingRound: LightGbmTrainerOptions.EarlyStoppingRound)) { TrainedEnsemble = bst.GetModel(catMetaData.CategoricalBoudaries); } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index e46592eec4..7b51f19ba8 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -48,7 +48,7 @@ public sealed class RandomizedPcaTrainer : TrainerEstimatorBase : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer where TModel : class - where TArgs : LbfgsTrainerBase.ArgumentsBase, new () + where TArgs : LbfgsTrainerBase.OptionsBase, new () { - public abstract class ArgumentsBase : LearnerInputBaseWithWeight + public abstract class OptionsBase : LearnerInputBaseWithWeight { [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)] [TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")] diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 572def3f98..430c889550 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -40,7 +40,7 @@ public sealed partial class LogisticRegression : LbfgsTrainerBase /// If set to truetraining statistics will be generated at the end of training. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index f003566bc0..2a37b751a6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -44,7 +44,7 @@ public sealed class MulticlassLogisticRegression : LbfgsTrainerBase /// Arguments class for averaged linear trainers. /// - public abstract class AveragedLinearArguments : OnlineLinearArguments + public abstract class AveragedLinearOptions : OnlineLinearOptions { /// /// Learning rate. @@ -25,7 +25,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)] [TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")] [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })] - public float LearningRate = AveragedDefaultArgs.LearningRate; + public float LearningRate = AveragedDefault.LearningRate; /// /// Determine whether to decrease the or not. @@ -37,7 +37,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)] [TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")] [TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })] - public bool DecreaseLearningRate = AveragedDefaultArgs.DecreaseLearningRate; + public bool DecreaseLearningRate = AveragedDefault.DecreaseLearningRate; /// /// Number of examples after which weights will be reset to the current average. @@ -65,7 +65,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)] [TGUI(Label = "L2 Regularization Weight")] [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)] - public float L2RegularizerWeight = AveragedDefaultArgs.L2RegularizerWeight; + public float L2RegularizerWeight = AveragedDefault.L2RegularizerWeight; /// /// Extra weight given to more recent updates. @@ -104,7 +104,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments internal float AveragedTolerance = (float)1e-2; [BestFriend] - internal class AveragedDefaultArgs : OnlineDefaultArgs + internal class AveragedDefault : OnlineLinearOptions.OnlineDefault { public const float LearningRate = 1; public const bool DecreaseLearningRate = false; @@ -118,7 +118,7 @@ public abstract class AveragedLinearTrainer : OnlineLinear where TTransformer : ISingleFeaturePredictionTransformer where TModel : class { - protected readonly new AveragedLinearArguments Args; + protected readonly AveragedLinearOptions AveragedLinearTrainerOptions; protected IScalarOutputLoss LossFunction; private protected abstract class AveragedTrainStateBase : TrainStateBase @@ -140,7 +140,7 @@ private protected abstract class AveragedTrainStateBase : TrainStateBase protected readonly bool Averaged; private readonly long _resetWeightsAfterXExamples; - private readonly AveragedLinearArguments _args; + private readonly AveragedLinearOptions _args; private readonly IScalarOutputLoss _loss; private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer parent) @@ -148,10 +148,10 @@ private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearMod { // Do the other initializations by setting the setters as if user had set them // Initialize the averaged weights if needed (i.e., do what happens when Averaged is set) - Averaged = parent.Args.Averaged; + Averaged = parent.AveragedLinearTrainerOptions.Averaged; if (Averaged) { - if (parent.Args.AveragedTolerance > 0) + if (parent.AveragedLinearTrainerOptions.AveragedTolerance > 0) VBufferUtils.Densify(ref Weights); Weights.CopyTo(ref TotalWeights); } @@ -161,8 +161,8 @@ private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearMod // to another vector with each update. VBufferUtils.Densify(ref Weights); } - _resetWeightsAfterXExamples = parent.Args.ResetWeightsAfterXExamples ?? 0; - _args = parent.Args; + _resetWeightsAfterXExamples = parent.AveragedLinearTrainerOptions.ResetWeightsAfterXExamples ?? 0; + _args = parent.AveragedLinearTrainerOptions; _loss = parent.LossFunction; Gain = 1; @@ -295,20 +295,20 @@ private void IncrementAverageNonLazy() } } - protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) - : base(args, env, name, label) + protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label) + : base(options, env, name, label) { - Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive); - Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive); + Contracts.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), UserErrorPositive); + Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive); // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible. - Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)"); - Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative); - Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative); + Contracts.CheckUserArg(0 <= options.L2RegularizerWeight && options.L2RegularizerWeight < 0.5, nameof(options.L2RegularizerWeight), "must be in range [0, 0.5)"); + Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative); + Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative); // Verify user didn't specify parameters that conflict - Contracts.Check(!args.DoLazyUpdates || !args.RecencyGainMulti && args.RecencyGain == 0, "Cannot have both recency gain and lazy updates."); + Contracts.Check(!options.DoLazyUpdates || !options.RecencyGainMulti && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates."); - Args = args; + AveragedLinearTrainerOptions = options; } } -} \ No newline at end of file +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index 0fe7a7ea15..5ba6dfe2d7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -56,13 +56,13 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer /// Options for the averaged perceptron trainer. /// - public sealed class Options : AveragedLinearArguments + public sealed class Options : AveragedLinearOptions { /// /// A custom loss. /// [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments(); + public ISupportClassificationLossFactory LossFunction = new HingeLoss.Options(); /// /// The calibrator for producing probabilities. Default is exponential (aka Platt) calibration. @@ -132,10 +132,10 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, IClassificationLoss lossFunction = null, - float learningRate = Options.AveragedDefaultArgs.LearningRate, - bool decreaseLearningRate = Options.AveragedDefaultArgs.DecreaseLearningRate, - float l2RegularizerWeight = Options.AveragedDefaultArgs.L2RegularizerWeight, - int numIterations = Options.AveragedDefaultArgs.NumIterations) + float learningRate = Options.AveragedDefault.LearningRate, + bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate, + float l2RegularizerWeight = Options.AveragedDefault.L2RegularizerWeight, + int numIterations = Options.AveragedDefault.NumIterations) : this(env, new Options { LabelColumn = labelColumn, @@ -220,4 +220,4 @@ internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnviro calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples); } } -} \ No newline at end of file +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index e7b04996ac..14cc1ca85d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -40,7 +40,7 @@ public sealed class LinearSvmTrainer : OnlineLinearTrainer LossFunctionFactory => LossFunction; [BestFriend] - internal class OgdDefaultArgs : AveragedDefaultArgs + internal class OgdDefaultArgs : AveragedDefault { public new const float LearningRate = 0.1f; public new const bool DecreaseLearningRate = true; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index d8ff4b58be..76d0bd6ea0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Trainers.Online /// /// Arguments class for online linear trainers. /// - public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel + public abstract class OnlineLinearOptions : LearnerInputBaseWithLabel { /// /// Number of passes through the training dataset. @@ -27,7 +27,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter, numIterations", SortOrder = 50)] [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")] [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)] - public int NumberOfIterations = OnlineDefaultArgs.NumIterations; + public int NumberOfIterations = OnlineDefault.NumIterations; /// /// Initial weights and bias, comma-separated. @@ -60,7 +60,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel public bool Shuffle = true; [BestFriend] - internal class OnlineDefaultArgs + internal class OnlineDefault { public const int NumIterations = 1; } @@ -70,7 +70,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat where TTransformer : ISingleFeaturePredictionTransformer where TModel : class { - protected readonly OnlineLinearArguments Args; + protected readonly OnlineLinearOptions OnlineLinearTrainerOptions; protected readonly string Name; /// @@ -136,10 +136,10 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre VBufferUtils.Densify(ref Weights); Bias = predictor.Bias; } - else if (!string.IsNullOrWhiteSpace(parent.Args.InitialWeights)) + else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights)) { - ch.Info("Initializing weights and bias to " + parent.Args.InitialWeights); - string[] weightStr = parent.Args.InitialWeights.Split(','); + ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights); + string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(','); if (weightStr.Length != numFeatures + 1) { throw ch.Except( @@ -153,13 +153,13 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre Weights = new VBuffer(numFeatures, weightValues); Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture); } - else if (parent.Args.InitialWeightsDiameter > 0) + else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0) { var weightValues = new float[numFeatures]; for (int i = 0; i < numFeatures; i++) - weightValues[i] = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); Weights = new VBuffer(numFeatures, weightValues); - Bias = parent.Args.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); + Bias = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5); } else if (numFeatures <= 1000) Weights = VBufferUtils.CreateDense(numFeatures); @@ -253,14 +253,14 @@ public virtual float Margin(in VBuffer feat) protected virtual bool NeedCalibration => false; - private protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) + private protected OnlineLinearTrainer(OnlineLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.InitialWeights)) { - Contracts.CheckValue(args, nameof(args)); - Contracts.CheckUserArg(args.NumberOfIterations > 0, nameof(args.NumberOfIterations), UserErrorPositive); - Contracts.CheckUserArg(args.InitialWeightsDiameter >= 0, nameof(args.InitialWeightsDiameter), UserErrorNonNegative); + Contracts.CheckValue(options, nameof(options)); + Contracts.CheckUserArg(options.NumberOfIterations > 0, nameof(options.NumberOfIterations), UserErrorPositive); + Contracts.CheckUserArg(options.InitialWeightsDiameter >= 0, nameof(options.InitialWeightsDiameter), UserErrorNonNegative); - Args = args; + OnlineLinearTrainerOptions = options; Name = name; // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue. Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true); @@ -309,7 +309,7 @@ public TTransformer Fit(IDataView trainData, IPredictor initialPredictor) private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) { - bool shuffle = Args.Shuffle; + bool shuffle = OnlineLinearTrainerOptions.Shuffle; if (shuffle && !data.Data.CanShuffle) { ch.Warning("Training data does not support shuffling, so ignoring request to shuffle"); @@ -323,7 +323,7 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt); long numBad = 0; - while (state.Iteration < Args.NumberOfIterations) + while (state.Iteration < OnlineLinearTrainerOptions.NumberOfIterations) { state.BeginIteration(ch); @@ -341,10 +341,10 @@ private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state) { ch.Warning( "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)", - numBad, Args.NumberOfIterations, numBad / Args.NumberOfIterations); + numBad, OnlineLinearTrainerOptions.NumberOfIterations, numBad / OnlineLinearTrainerOptions.NumberOfIterations); } } private protected abstract TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor); } -} \ No newline at end of file +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs index 4ee904c76e..d09a1d9633 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs @@ -33,7 +33,7 @@ public sealed class PoissonRegression : LbfgsTrainerBase : StochasticTrainerBase where TTransformer : ISingleFeaturePredictionTransformer where TModel : class - where TArgs : SdcaTrainerBase.ArgumentsBase, new() + where TArgs : SdcaTrainerBase.OptionsBase, new() { // REVIEW: Making it even faster and more accurate: // 1. Train with not-too-many threads. nt = 2 or 4 seems to be good enough. Didn't seem additional benefit over more threads. @@ -159,7 +159,7 @@ public abstract class SdcaTrainerBase : StochasticT // 3. Don't "guess" the iteration to converge. It is very data-set dependent and hard to control. Always check for at least once to ensure convergence. // 4. Use dual variable updates to infer whether a full iteration of convergence checking is necessary. Convergence checking iteration is time-consuming. - public abstract class ArgumentsBase : LearnerInputBaseWithLabel + public abstract class OptionsBase : LearnerInputBaseWithLabel { [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant. By default the l2 constant is automatically inferred based on data set.", NullName = "", ShortName = "l2", SortOrder = 1)] [TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = ",1e-7,1e-6,1e-5,1e-4,1e-3,1e-2")] @@ -293,7 +293,7 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d if (Args.NumThreads.HasValue) { numThreads = Args.NumThreads.Value; - Host.CheckUserArg(numThreads > 0, nameof(ArgumentsBase.NumThreads), "The number of threads must be either null or a positive integer."); + Host.CheckUserArg(numThreads > 0, nameof(OptionsBase.NumThreads), "The number of threads must be either null or a positive integer."); if (0 < Host.ConcurrencyFactor && Host.ConcurrencyFactor < numThreads) { numThreads = Host.ConcurrencyFactor; @@ -1417,7 +1417,7 @@ public abstract class SdcaBinaryTrainerBase : public override TrainerInfo Info { get; } - public class BinaryArgumentBase : ArgumentsBase + public class BinaryArgumentBase : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")] public float PositiveInstanceWeight = 1; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 1e61c63490..c0f7ea4446 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -35,7 +35,7 @@ public class SdcaMultiClassTrainer : SdcaTrainerBase score, Scalar predictedLabel) AveragedPercept Vector features, Scalar weights = null, IClassificationLoss lossFunction = null, - float learningRate = AveragedLinearArguments.AveragedDefaultArgs.LearningRate, - bool decreaseLearningRate = AveragedLinearArguments.AveragedDefaultArgs.DecreaseLearningRate, - float l2RegularizerWeight = AveragedLinearArguments.AveragedDefaultArgs.L2RegularizerWeight, - int numIterations = AveragedLinearArguments.AveragedDefaultArgs.NumIterations, + float learningRate = AveragedLinearOptions.AveragedDefault.LearningRate, + bool decreaseLearningRate = AveragedLinearOptions.AveragedDefault.DecreaseLearningRate, + float l2RegularizerWeight = AveragedLinearOptions.AveragedDefault.L2RegularizerWeight, + int numIterations = AveragedLinearOptions.AveragedDefault.NumIterations, Action onFit = null ) { @@ -168,7 +168,7 @@ public static Scalar OnlineGradientDescent(this RegressionCatalog.Regress float learningRate = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.LearningRate, bool decreaseLearningRate = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.DecreaseLearningRate, float l2RegularizerWeight = OnlineGradientDescentTrainer.Options.OgdDefaultArgs.L2RegularizerWeight, - int numIterations = OnlineLinearArguments.OnlineDefaultArgs.NumIterations, + int numIterations = OnlineLinearOptions.OnlineDefault.NumIterations, Action onFit = null) { OnlineLinearStaticUtils.CheckUserParams(label, features, weights, learningRate, l2RegularizerWeight, numIterations, onFit); diff --git a/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs b/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs index 5d6aeb6c20..4a114cd88b 100644 --- a/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TextLoaderStatic.cs @@ -45,7 +45,7 @@ public static DataReader CreateReader<[IsShape] TSha env.CheckValueOrNull(files); // Populate all args except the columns. - var args = new TextLoader.Arguments(); + var args = new TextLoader.Options(); args.AllowQuoting = allowQuoting; args.AllowSparse = allowSparse; args.HasHeader = hasHeader; @@ -65,15 +65,15 @@ public static DataReader CreateReader<[IsShape] TSha private sealed class TextReconciler : ReaderReconciler { - private readonly TextLoader.Arguments _args; + private readonly TextLoader.Options _args; private readonly IMultiStreamSource _files; - public TextReconciler(TextLoader.Arguments args, IMultiStreamSource files) + public TextReconciler(TextLoader.Options options, IMultiStreamSource files) { - Contracts.AssertValue(args); + Contracts.AssertValue(options); Contracts.AssertValueOrNull(files); - _args = args; + _args = options; _files = files; } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs index 1466b78ae4..281fdbc0ce 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/Grid.cs @@ -9,9 +9,9 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Sweeper; -[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureSweeper), +[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureSweeper), "Random Grid Sweeper", "RandomGridSweeper", "RandomGrid")] -[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureSweeperFromParameterList), +[assembly: LoadableClass(typeof(RandomGridSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureSweeperFromParameterList), "Random Grid Sweeper", "RandomGridSweeperParamList", "RandomGridpl")] namespace Microsoft.ML.Sweeper @@ -26,7 +26,7 @@ namespace Microsoft.ML.Sweeper /// public abstract class SweeperBase : ISweeper { - public class ArgumentsBase + public class OptionsBase { [Argument(ArgumentType.Multiple, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] public IComponentFactory[] SweptParameters; @@ -35,11 +35,11 @@ public class ArgumentsBase public int Retries = 10; } - private readonly ArgumentsBase _args; + private readonly OptionsBase _args; protected readonly IValueGenerator[] SweepParameters; protected readonly IHost Host; - protected SweeperBase(ArgumentsBase args, IHostEnvironment env, string name) + protected SweeperBase(OptionsBase args, IHostEnvironment env, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -52,7 +52,7 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, string name) SweepParameters = args.SweptParameters.Select(p => p.CreateComponent(Host)).ToArray(); } - protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[] sweepParameters, string name) + protected SweeperBase(OptionsBase args, IHostEnvironment env, IValueGenerator[] sweepParameters, string name) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -110,20 +110,20 @@ public sealed class RandomGridSweeper : SweeperBase // This is a parallel array to the _permutation array and stores the (already generated) parameter sets private readonly ParameterSet[] _cache; - public sealed class Arguments : ArgumentsBase + public sealed class Options : OptionsBase { [Argument(ArgumentType.LastOccurenceWins, HelpText = "Limit for the number of combinations to generate the entire grid.", ShortName = "maxpoints")] public int MaxGridPoints = 1000000; } - public RandomGridSweeper(IHostEnvironment env, Arguments args) - : base(args, env, "RandomGrid") + public RandomGridSweeper(IHostEnvironment env, Options options) + : base(options, env, "RandomGrid") { _nGridPoints = 1; foreach (var sweptParameter in SweepParameters) { _nGridPoints *= sweptParameter.Count; - if (_nGridPoints > args.MaxGridPoints) + if (_nGridPoints > options.MaxGridPoints) _nGridPoints = 0; } if (_nGridPoints != 0) @@ -133,14 +133,14 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args) } } - public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[] sweepParameters) - : base(args, env, sweepParameters, "RandomGrid") + public RandomGridSweeper(IHostEnvironment env, Options options, IValueGenerator[] sweepParameters) + : base(options, env, sweepParameters, "RandomGrid") { _nGridPoints = 1; foreach (var sweptParameter in SweepParameters) { _nGridPoints *= sweptParameter.Count; - if (_nGridPoints > args.MaxGridPoints) + if (_nGridPoints > options.MaxGridPoints) _nGridPoints = 0; } if (_nGridPoints != 0) diff --git a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs index 469988ec7c..cf413f5aad 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/KdoSweeper.cs @@ -12,7 +12,7 @@ using Microsoft.ML.Trainers.FastTree; using Float = System.Single; -[assembly: LoadableClass(typeof(KdoSweeper), typeof(KdoSweeper.Arguments), typeof(SignatureSweeper), +[assembly: LoadableClass(typeof(KdoSweeper), typeof(KdoSweeper.Options), typeof(SignatureSweeper), "KDO Sweeper", "KDOSweeper", "KDO")] namespace Microsoft.ML.Sweeper.Algorithms @@ -36,7 +36,7 @@ namespace Microsoft.ML.Sweeper.Algorithms public sealed class KdoSweeper : ISweeper { - public sealed class Arguments + public sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] public IComponentFactory[] SweptParameters; @@ -77,7 +77,7 @@ public sealed class Arguments private readonly ISweeper _randomSweeper; private readonly ISweeper _redundantSweeper; - private readonly Arguments _args; + private readonly Options _args; private readonly IHost _host; private readonly IValueGenerator[] _sweepParameters; @@ -85,22 +85,22 @@ public sealed class Arguments private readonly SortedSet _alreadySeenConfigs; private readonly List _randomParamSets; - public KdoSweeper(IHostEnvironment env, Arguments args) + public KdoSweeper(IHostEnvironment env, Options options) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("Sweeper"); - _host.CheckUserArg(args.NumberInitialPopulation > 1, nameof(args.NumberInitialPopulation), "Must be greater than 1"); - _host.CheckUserArg(args.HistoryLength > 1, nameof(args.HistoryLength), "Must be greater than 1"); - _host.CheckUserArg(args.MinimumMutationSpread >= 0, nameof(args.MinimumMutationSpread), "Must be nonnegative"); - _host.CheckUserArg(0 <= args.ProportionRandom && args.ProportionRandom <= 1, nameof(args.ProportionRandom), "Must be in [0, 1]"); - _host.CheckUserArg(args.WeightRescalingPower >= 1, nameof(args.WeightRescalingPower), "Must be greater or equal to 1"); - - _args = args; - _host.CheckUserArg(Utils.Size(args.SweptParameters) > 0, nameof(args.SweptParameters), "KDO sweeper needs at least one parameter to sweep over"); - _sweepParameters = args.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray(); - _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), _sweepParameters); - _redundantSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase { Retries = 0 }, _sweepParameters); + _host.CheckUserArg(options.NumberInitialPopulation > 1, nameof(options.NumberInitialPopulation), "Must be greater than 1"); + _host.CheckUserArg(options.HistoryLength > 1, nameof(options.HistoryLength), "Must be greater than 1"); + _host.CheckUserArg(options.MinimumMutationSpread >= 0, nameof(options.MinimumMutationSpread), "Must be nonnegative"); + _host.CheckUserArg(0 <= options.ProportionRandom && options.ProportionRandom <= 1, nameof(options.ProportionRandom), "Must be in [0, 1]"); + _host.CheckUserArg(options.WeightRescalingPower >= 1, nameof(options.WeightRescalingPower), "Must be greater or equal to 1"); + + _args = options; + _host.CheckUserArg(Utils.Size(options.SweptParameters) > 0, nameof(options.SweptParameters), "KDO sweeper needs at least one parameter to sweep over"); + _sweepParameters = options.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray(); + _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase(), _sweepParameters); + _redundantSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase { Retries = 0 }, _sweepParameters); _spu = new SweeperProbabilityUtils(_host); _alreadySeenConfigs = new SortedSet(new FloatArrayComparer()); _randomParamSets = new List(); diff --git a/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs b/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs index 51dd6c91f0..14c58ca49a 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs @@ -11,20 +11,20 @@ using Microsoft.ML.Sweeper; using Float = System.Single; -[assembly: LoadableClass(typeof(NelderMeadSweeper), typeof(NelderMeadSweeper.Arguments), typeof(SignatureSweeper), +[assembly: LoadableClass(typeof(NelderMeadSweeper), typeof(NelderMeadSweeper.Options), typeof(SignatureSweeper), "Nelder Mead Sweeper", "NelderMeadSweeper", "NelderMead", "NM")] namespace Microsoft.ML.Sweeper { public sealed class NelderMeadSweeper : ISweeper { - public sealed class Arguments + public sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] public IComponentFactory[] SweptParameters; [Argument(ArgumentType.LastOccurenceWins, HelpText = "The sweeper used to get the initial results.", ShortName = "init", SignatureType = typeof(SignatureSweeperFromParameterList))] - public IComponentFactory FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction((host, array) => new UniformRandomSweeper(host, new SweeperBase.ArgumentsBase(), array)); + public IComponentFactory FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction((host, array) => new UniformRandomSweeper(host, new SweeperBase.OptionsBase(), array)); [Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random number generator for the first batch sweeper", ShortName = "seed")] public int RandomSeed; @@ -66,7 +66,7 @@ private enum OptimizationStage } private readonly ISweeper _initSweeper; - private readonly Arguments _args; + private readonly Options _args; private SortedList _simplexVertices; private readonly int _dim; @@ -84,21 +84,21 @@ private enum OptimizationStage private readonly List _sweepParameters; - public NelderMeadSweeper(IHostEnvironment env, Arguments args) + public NelderMeadSweeper(IHostEnvironment env, Options options) { Contracts.CheckValue(env, nameof(env)); - env.CheckUserArg(-1 < args.DeltaInsideContraction, nameof(args.DeltaInsideContraction), "Must be greater than -1"); - env.CheckUserArg(args.DeltaInsideContraction < 0, nameof(args.DeltaInsideContraction), "Must be less than 0"); - env.CheckUserArg(0 < args.DeltaOutsideContraction, nameof(args.DeltaOutsideContraction), "Must be greater than 0"); - env.CheckUserArg(args.DeltaReflection > args.DeltaOutsideContraction, nameof(args.DeltaReflection), "Must be greater than " + nameof(args.DeltaOutsideContraction)); - env.CheckUserArg(args.DeltaExpansion > args.DeltaReflection, nameof(args.DeltaExpansion), "Must be greater than " + nameof(args.DeltaReflection)); - env.CheckUserArg(0 < args.GammaShrink && args.GammaShrink < 1, nameof(args.GammaShrink), "Must be between 0 and 1"); - env.CheckValue(args.FirstBatchSweeper, nameof(args.FirstBatchSweeper) , "First Batch Sweeper Contains Null Value"); + env.CheckUserArg(-1 < options.DeltaInsideContraction, nameof(options.DeltaInsideContraction), "Must be greater than -1"); + env.CheckUserArg(options.DeltaInsideContraction < 0, nameof(options.DeltaInsideContraction), "Must be less than 0"); + env.CheckUserArg(0 < options.DeltaOutsideContraction, nameof(options.DeltaOutsideContraction), "Must be greater than 0"); + env.CheckUserArg(options.DeltaReflection > options.DeltaOutsideContraction, nameof(options.DeltaReflection), "Must be greater than " + nameof(options.DeltaOutsideContraction)); + env.CheckUserArg(options.DeltaExpansion > options.DeltaReflection, nameof(options.DeltaExpansion), "Must be greater than " + nameof(options.DeltaReflection)); + env.CheckUserArg(0 < options.GammaShrink && options.GammaShrink < 1, nameof(options.GammaShrink), "Must be between 0 and 1"); + env.CheckValue(options.FirstBatchSweeper, nameof(options.FirstBatchSweeper) , "First Batch Sweeper Contains Null Value"); - _args = args; + _args = options; _sweepParameters = new List(); - foreach (var sweptParameter in args.SweptParameters) + foreach (var sweptParameter in options.SweptParameters) { var parameter = sweptParameter.CreateComponent(env); // REVIEW: ideas about how to support discrete values: @@ -108,13 +108,13 @@ public NelderMeadSweeper(IHostEnvironment env, Arguments args) // the metric values that we get when using them. (For example, if, for a given discrete value, we get a bad result, // we lower its weight, but if we get a good result we increase its weight). var parameterNumeric = parameter as INumericValueGenerator; - env.CheckUserArg(parameterNumeric != null, nameof(args.SweptParameters), "Nelder-Mead sweeper can only sweep over numeric parameters"); + env.CheckUserArg(parameterNumeric != null, nameof(options.SweptParameters), "Nelder-Mead sweeper can only sweep over numeric parameters"); _sweepParameters.Add(parameterNumeric); } - _initSweeper = args.FirstBatchSweeper.CreateComponent(env, _sweepParameters.ToArray()); + _initSweeper = options.FirstBatchSweeper.CreateComponent(env, _sweepParameters.ToArray()); _dim = _sweepParameters.Count; - env.CheckUserArg(_dim > 1, nameof(args.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over."); + env.CheckUserArg(_dim > 1, nameof(options.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over."); _simplexVertices = new SortedList(new SimplexVertexComparer()); _stage = OptimizationStage.NeedReflectionPoint; diff --git a/src/Microsoft.ML.Sweeper/Algorithms/Random.cs b/src/Microsoft.ML.Sweeper/Algorithms/Random.cs index 6711696bf4..105c6b648d 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/Random.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/Random.cs @@ -6,9 +6,9 @@ using Microsoft.ML; using Microsoft.ML.Sweeper; -[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureSweeper), +[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureSweeper), "Uniform Random Sweeper", "UniformRandomSweeper", "UniformRandom")] -[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureSweeperFromParameterList), +[assembly: LoadableClass(typeof(UniformRandomSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureSweeperFromParameterList), "Uniform Random Sweeper", "UniformRandomSweeperParamList", "UniformRandompl")] namespace Microsoft.ML.Sweeper @@ -18,12 +18,12 @@ namespace Microsoft.ML.Sweeper /// public sealed class UniformRandomSweeper : SweeperBase { - public UniformRandomSweeper(IHostEnvironment env, ArgumentsBase args) + public UniformRandomSweeper(IHostEnvironment env, OptionsBase args) : base(args, env, "UniformRandom") { } - public UniformRandomSweeper(IHostEnvironment env, ArgumentsBase args, IValueGenerator[] sweepParameters) + public UniformRandomSweeper(IHostEnvironment env, OptionsBase args, IValueGenerator[] sweepParameters) : base(args, env, sweepParameters, "UniformRandom") { } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index b7b1c9563e..6053bdc4bd 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Trainers.FastTree; using Float = System.Single; -[assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Arguments), typeof(SignatureSweeper), +[assembly: LoadableClass(typeof(SmacSweeper), typeof(SmacSweeper.Options), typeof(SignatureSweeper), "SMAC Sweeper", "SMACSweeper", "SMAC")] namespace Microsoft.ML.Sweeper @@ -24,7 +24,7 @@ namespace Microsoft.ML.Sweeper //encapsulating common functionality. This seems like a good plan to persue. public sealed class SmacSweeper : ISweeper { - public sealed class Arguments + public sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] public IComponentFactory[] SweptParameters; @@ -61,28 +61,28 @@ public sealed class Arguments } private readonly ISweeper _randomSweeper; - private readonly Arguments _args; + private readonly Options _args; private readonly IHost _host; private readonly IValueGenerator[] _sweepParameters; - public SmacSweeper(IHostEnvironment env, Arguments args) + public SmacSweeper(IHostEnvironment env, Options options) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("Sweeper"); - _host.CheckUserArg(args.NumOfTrees > 0, nameof(args.NumOfTrees), "parameter must be greater than 0"); - _host.CheckUserArg(args.NMinForSplit > 1, nameof(args.NMinForSplit), "parameter must be greater than 1"); - _host.CheckUserArg(args.SplitRatio > 0 && args.SplitRatio <= 1, nameof(args.SplitRatio), "parameter must be in range (0,1]."); - _host.CheckUserArg(args.NumberInitialPopulation > 1, nameof(args.NumberInitialPopulation), "parameter must be greater than 1"); - _host.CheckUserArg(args.LocalSearchParentCount > 0, nameof(args.LocalSearchParentCount), "parameter must be greater than 0"); - _host.CheckUserArg(args.NumRandomEISearchConfigurations > 0, nameof(args.NumRandomEISearchConfigurations), "parameter must be greater than 0"); - _host.CheckUserArg(args.NumNeighborsForNumericalParams > 0, nameof(args.NumNeighborsForNumericalParams), "parameter must be greater than 0"); - - _args = args; - _host.CheckUserArg(Utils.Size(args.SweptParameters) > 0, nameof(args.SweptParameters), "SMAC sweeper needs at least one parameter to sweep over"); - _sweepParameters = args.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray(); - _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), _sweepParameters); + _host.CheckUserArg(options.NumOfTrees > 0, nameof(options.NumOfTrees), "parameter must be greater than 0"); + _host.CheckUserArg(options.NMinForSplit > 1, nameof(options.NMinForSplit), "parameter must be greater than 1"); + _host.CheckUserArg(options.SplitRatio > 0 && options.SplitRatio <= 1, nameof(options.SplitRatio), "parameter must be in range (0,1]."); + _host.CheckUserArg(options.NumberInitialPopulation > 1, nameof(options.NumberInitialPopulation), "parameter must be greater than 1"); + _host.CheckUserArg(options.LocalSearchParentCount > 0, nameof(options.LocalSearchParentCount), "parameter must be greater than 0"); + _host.CheckUserArg(options.NumRandomEISearchConfigurations > 0, nameof(options.NumRandomEISearchConfigurations), "parameter must be greater than 0"); + _host.CheckUserArg(options.NumNeighborsForNumericalParams > 0, nameof(options.NumNeighborsForNumericalParams), "parameter must be greater than 0"); + + _args = options; + _host.CheckUserArg(Utils.Size(options.SweptParameters) > 0, nameof(options.SweptParameters), "SMAC sweeper needs at least one parameter to sweep over"); + _sweepParameters = options.SweptParameters.Select(p => p.CreateComponent(_host)).ToArray(); + _randomSweeper = new UniformRandomSweeper(env, new SweeperBase.OptionsBase(), _sweepParameters); } public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previousRuns = null) diff --git a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs index 4713bb54dd..579f20c53b 100644 --- a/src/Microsoft.ML.Sweeper/AsyncSweeper.cs +++ b/src/Microsoft.ML.Sweeper/AsyncSweeper.cs @@ -13,11 +13,11 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Sweeper; -[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(SweeperBase.ArgumentsBase), typeof(SignatureAsyncSweeper), +[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureAsyncSweeper), "Asynchronous Uniform Random Sweeper", "UniformRandomSweeper", "UniformRandom")] -[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(RandomGridSweeper.Arguments), typeof(SignatureAsyncSweeper), +[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureAsyncSweeper), "Asynchronous Random Grid Sweeper", "RandomGridSweeper", "RandomGrid")] -[assembly: LoadableClass(typeof(DeterministicSweeperAsync), typeof(DeterministicSweeperAsync.Arguments), typeof(SignatureAsyncSweeper), +[assembly: LoadableClass(typeof(DeterministicSweeperAsync), typeof(DeterministicSweeperAsync.Options), typeof(SignatureAsyncSweeper), "Asynchronous and Deterministic Sweeper", "DeterministicSweeper", "Deterministic")] namespace Microsoft.ML.Sweeper @@ -88,13 +88,13 @@ private SimpleAsyncSweeper(ISweeper baseSweeper) _results = new List(); } - public SimpleAsyncSweeper(IHostEnvironment env, UniformRandomSweeper.ArgumentsBase args) - : this(new UniformRandomSweeper(env, args)) + public SimpleAsyncSweeper(IHostEnvironment env, UniformRandomSweeper.OptionsBase options) + : this(new UniformRandomSweeper(env, options)) { } - public SimpleAsyncSweeper(IHostEnvironment env, RandomGridSweeper.Arguments args) - : this(new UniformRandomSweeper(env, args)) + public SimpleAsyncSweeper(IHostEnvironment env, RandomGridSweeper.Options options) + : this(new UniformRandomSweeper(env, options)) { } @@ -150,7 +150,7 @@ public void Dispose() /// public sealed class DeterministicSweeperAsync : IAsyncSweeper, IDisposable { - public sealed class Arguments + public sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Base sweeper", ShortName = "sweeper", SignatureType = typeof(SignatureSweeper))] public IComponentFactory Sweeper; @@ -190,21 +190,21 @@ public sealed class Arguments // The number of ParameterSets generated so far. Used for indexing. private int _numGenerated; - public DeterministicSweeperAsync(IHostEnvironment env, Arguments args) + public DeterministicSweeperAsync(IHostEnvironment env, Options options) { - _host = env.Register("DeterministicSweeperAsync", args.RandomSeed); - _host.CheckValue(args.Sweeper, nameof(args.Sweeper), "Please specify a sweeper"); - _host.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), "Batch size must be positive"); - _host.CheckUserArg(args.Relaxation >= 0, nameof(args.Relaxation), "Synchronization relaxation must be non-negative"); - _host.CheckUserArg(args.Relaxation <= args.BatchSize, nameof(args.Relaxation), + _host = env.Register("DeterministicSweeperAsync", options.RandomSeed); + _host.CheckValue(options.Sweeper, nameof(options.Sweeper), "Please specify a sweeper"); + _host.CheckUserArg(options.BatchSize > 0, nameof(options.BatchSize), "Batch size must be positive"); + _host.CheckUserArg(options.Relaxation >= 0, nameof(options.Relaxation), "Synchronization relaxation must be non-negative"); + _host.CheckUserArg(options.Relaxation <= options.BatchSize, nameof(options.Relaxation), "Synchronization relaxation cannot be larger than batch size"); - _batchSize = args.BatchSize; - _baseSweeper = args.Sweeper.CreateComponent(_host); - _host.CheckUserArg(!(_baseSweeper is NelderMeadSweeper) || args.Relaxation == 0, nameof(args.Relaxation), + _batchSize = options.BatchSize; + _baseSweeper = options.Sweeper.CreateComponent(_host); + _host.CheckUserArg(!(_baseSweeper is NelderMeadSweeper) || options.Relaxation == 0, nameof(options.Relaxation), "Nelder-Mead requires full synchronization (relaxation = 0)"); _cts = new CancellationTokenSource(); - _relaxation = args.Relaxation; + _relaxation = options.Relaxation; _lock = new object(); _results = new List(); _nullRuns = new HashSet(); diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs index cbe2a99900..5f2e5269f7 100644 --- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs @@ -14,7 +14,7 @@ using ResultProcessorInternal = Microsoft.ML.ResultProcessor; -[assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Arguments), typeof(SignatureConfigRunner), +[assembly: LoadableClass(typeof(LocalExeConfigRunner), typeof(LocalExeConfigRunner.Options), typeof(SignatureConfigRunner), "Local Sweep Config Runner", "Local")] namespace Microsoft.ML.Sweeper @@ -30,7 +30,7 @@ public interface IConfigRunner public abstract class ExeConfigRunnerBase : IConfigRunner { - public abstract class ArgumentsBase + public abstract class OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Command pattern for the sweeps", ShortName = "pattern")] public string ArgsPattern; @@ -46,7 +46,7 @@ public abstract class ArgumentsBase [Argument(ArgumentType.Multiple, HelpText = "Specify how to extract the metrics from the result file.", ShortName = "ev", SignatureType = typeof(SignatureSweepResultEvaluator))] public IComponentFactory> ResultProcessor = ComponentFactoryUtils.CreateFromFunction( - env => new InternalSweepResultEvaluator(env, new InternalSweepResultEvaluator.Arguments())); + env => new InternalSweepResultEvaluator(env, new InternalSweepResultEvaluator.Options())); [Argument(ArgumentType.AtMostOnce, Hide = true)] public bool CalledFromUnitTestSuite; @@ -63,7 +63,7 @@ public abstract class ArgumentsBase private readonly bool _calledFromUnitTestSuite; - protected ExeConfigRunnerBase(ArgumentsBase args, IHostEnvironment env, string registrationName) + protected ExeConfigRunnerBase(OptionsBase args, IHostEnvironment env, string registrationName) { Contracts.AssertValue(env); Host = env.Register(registrationName); @@ -82,7 +82,7 @@ protected virtual void ProcessFullExePath(string exe) Exe = GetFullExePath(exe); if (!File.Exists(Exe) && !File.Exists(Exe + ".exe")) - throw Host.ExceptUserArg(nameof(ArgumentsBase.Exe), "Executable {0} not found", Exe); + throw Host.ExceptUserArg(nameof(OptionsBase.Exe), "Executable {0} not found", Exe); } protected virtual string GetFullExePath(string exe) @@ -180,7 +180,7 @@ protected string GetFilePath(int i, string kind) public sealed class LocalExeConfigRunner : ExeConfigRunnerBase { - public sealed class Arguments : ArgumentsBase + public sealed class Options : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "The number of threads to use for the sweep (default auto determined by the number of cores)", ShortName = "t")] public int? NumThreads; @@ -188,13 +188,13 @@ public sealed class Arguments : ArgumentsBase private readonly ParallelOptions _parallelOptions; - public LocalExeConfigRunner(IHostEnvironment env, Arguments args) - : base(args, env, "LocalExeSweepEvaluator") + public LocalExeConfigRunner(IHostEnvironment env, Options options) + : base(options, env, "LocalExeSweepEvaluator") { - Contracts.CheckParam(args.NumThreads == null || args.NumThreads.Value > 0, nameof(args.NumThreads), "Cannot be 0 or negative"); - _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = args.NumThreads ?? -1 }; - Contracts.AssertNonEmpty(args.OutputFolderName); - ProcessFullExePath(args.Exe); + Contracts.CheckParam(options.NumThreads == null || options.NumThreads.Value > 0, nameof(options.NumThreads), "Cannot be 0 or negative"); + _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = options.NumThreads ?? -1 }; + Contracts.AssertNonEmpty(options.OutputFolderName); + ProcessFullExePath(options.Exe); } protected override IEnumerable RunConfigsCore(ParameterSet[] sweeps, IChannel ch, int min) diff --git a/src/Microsoft.ML.Sweeper/SweepCommand.cs b/src/Microsoft.ML.Sweeper/SweepCommand.cs index c1959272fb..2e27338040 100644 --- a/src/Microsoft.ML.Sweeper/SweepCommand.cs +++ b/src/Microsoft.ML.Sweeper/SweepCommand.cs @@ -24,7 +24,7 @@ public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "Config runner", ShortName = "run,ev,evaluator", SignatureType = typeof(SignatureConfigRunner))] public IComponentFactory Runner = ComponentFactoryUtils.CreateFromFunction( - env => new LocalExeConfigRunner(env, new LocalExeConfigRunner.Arguments())); + env => new LocalExeConfigRunner(env, new LocalExeConfigRunner.Options())); [Argument(ArgumentType.Multiple, HelpText = "Sweeper", ShortName = "s", SignatureType = typeof(SignatureSweeper))] public IComponentFactory Sweeper; diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs index a55f8d4647..5b5f298009 100644 --- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs +++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs @@ -10,14 +10,14 @@ using Microsoft.ML.Sweeper; using ResultProcessor = Microsoft.ML.ResultProcessor; -[assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Arguments), typeof(SignatureSweepResultEvaluator), +[assembly: LoadableClass(typeof(InternalSweepResultEvaluator), typeof(InternalSweepResultEvaluator.Options), typeof(SignatureSweepResultEvaluator), "TLC Sweep Result Evaluator", "TlcEvaluator", "Tlc")] namespace Microsoft.ML.Sweeper { public class InternalSweepResultEvaluator : ISweepResultEvaluator { - public class Arguments + public sealed class Options { [Argument(ArgumentType.LastOccurenceWins, HelpText = "The sweeper used to get the initial results.", ShortName = "m")] public string Metric = "AUC"; @@ -28,7 +28,7 @@ public class Arguments private readonly IHost _host; - public InternalSweepResultEvaluator(IHostEnvironment env, Arguments args) + public InternalSweepResultEvaluator(IHostEnvironment env, Options args) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("InternalSweepResultEvaluator"); diff --git a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs index d112c0659e..21555277af 100644 --- a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs @@ -11,7 +11,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Sweeper; -[assembly: LoadableClass(typeof(SynthConfigRunner), typeof(SynthConfigRunner.Arguments), typeof(SignatureConfigRunner), +[assembly: LoadableClass(typeof(SynthConfigRunner), typeof(SynthConfigRunner.Options), typeof(SignatureConfigRunner), "", "Synth")] namespace Microsoft.ML.Sweeper @@ -22,7 +22,7 @@ namespace Microsoft.ML.Sweeper /// public sealed class SynthConfigRunner : ExeConfigRunnerBase { - public sealed class Arguments : ArgumentsBase + public sealed class Options : OptionsBase { [Argument(ArgumentType.AtMostOnce, HelpText = "The number of threads to use for the sweep (default auto determined by the number of cores)", ShortName = "t")] public int? NumThreads; @@ -30,13 +30,13 @@ public sealed class Arguments : ArgumentsBase private readonly ParallelOptions _parallelOptions; - public SynthConfigRunner(IHostEnvironment env, Arguments args) - : base(args, env, "SynthSweepEvaluator") + public SynthConfigRunner(IHostEnvironment env, Options options) + : base(options, env, "SynthSweepEvaluator") { - Host.CheckUserArg(args.NumThreads == null || args.NumThreads.Value > 0, nameof(args.NumThreads), "Must be positive"); - _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = args.NumThreads ?? -1 }; - Host.AssertNonEmpty(args.OutputFolderName); - ProcessFullExePath(args.Exe); + Host.CheckUserArg(options.NumThreads == null || options.NumThreads.Value > 0, nameof(options.NumThreads), "Must be positive"); + _parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = options.NumThreads ?? -1 }; + Host.AssertNonEmpty(options.OutputFolderName); + ProcessFullExePath(options.Exe); } protected override IEnumerable RunConfigsCore(ParameterSet[] sweeps, IChannel ch, int min) diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs index 40915c484c..d64150c78d 100644 --- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs +++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs @@ -10,10 +10,10 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms; -[assembly: LoadableClass(typeof(GaussianFourierSampler), typeof(GaussianFourierSampler.Arguments), typeof(SignatureFourierDistributionSampler), +[assembly: LoadableClass(typeof(GaussianFourierSampler), typeof(GaussianFourierSampler.Options), typeof(SignatureFourierDistributionSampler), "Gaussian Kernel", GaussianFourierSampler.LoadName, "Gaussian")] -[assembly: LoadableClass(typeof(LaplacianFourierSampler), typeof(LaplacianFourierSampler.Arguments), typeof(SignatureFourierDistributionSampler), +[assembly: LoadableClass(typeof(LaplacianFourierSampler), typeof(LaplacianFourierSampler.Options), typeof(SignatureFourierDistributionSampler), "Laplacian Kernel", LaplacianFourierSampler.RegistrationName, "Laplacian")] // This is for deserialization from a binary model file. @@ -46,7 +46,7 @@ public sealed class GaussianFourierSampler : IFourierDistributionSampler { private readonly IHost _host; - public class Arguments : IFourierDistributionSamplerFactory + public sealed class Options : IFourierDistributionSamplerFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "gamma in the kernel definition: exp(-gamma*||x-y||^2 / r^2). r is an estimate of the average intra-example distance", ShortName = "g")] public float Gamma = 1; @@ -70,7 +70,7 @@ private static VersionInfo GetVersionInfo() private readonly float _gamma; - public GaussianFourierSampler(IHostEnvironment env, Arguments args, float avgDist) + public GaussianFourierSampler(IHostEnvironment env, Options args, float avgDist) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(LoadName); @@ -125,7 +125,7 @@ public float Next(Random rand) public sealed class LaplacianFourierSampler : IFourierDistributionSampler { - public class Arguments : IFourierDistributionSamplerFactory + public sealed class Options : IFourierDistributionSamplerFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "a in the term exp(-a|x| / r). r is an estimate of the average intra-example L1 distance")] public float A = 1; @@ -150,7 +150,7 @@ private static VersionInfo GetVersionInfo() private readonly IHost _host; private readonly float _a; - public LaplacianFourierSampler(IHostEnvironment env, Arguments args, float avgDist) + public LaplacianFourierSampler(IHostEnvironment env, Options args, float avgDist) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index c1fb8b58d9..2663b37257 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -132,30 +132,30 @@ private static IDataView Create(IHostEnvironment env, IDataView input, string ou } /// Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register("Categorical"); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); - h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); + h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns)); var replaceCols = new List(); var naIndicatorCols = new List(); var naConvCols = new List(); var concatCols = new List(); var dropCols = new List(); - var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Columns.Length, "IsMissing"); - var tmpReplaceColNames = input.Schema.GetTempColumnNames(args.Columns.Length, "Replace"); - for (int i = 0; i < args.Columns.Length; i++) + var tmpIsMissingColNames = input.Schema.GetTempColumnNames(options.Columns.Length, "IsMissing"); + var tmpReplaceColNames = input.Schema.GetTempColumnNames(options.Columns.Length, "Replace"); + for (int i = 0; i < options.Columns.Length; i++) { - var column = args.Columns[i]; + var column = options.Columns[i]; - var addInd = column.ConcatIndicator ?? args.Concat; + var addInd = column.ConcatIndicator ?? options.Concat; if (!addInd) { replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(column.Name, column.Source, - (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot)); continue; } @@ -190,7 +190,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options args, IDataV // Add the NAReplaceTransform column. replaceCols.Add(new MissingValueReplacingEstimator.ColumnInfo(tmpReplacementColName, column.Source, - (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + (MissingValueReplacingEstimator.ColumnInfo.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType is VectorType) diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index a07b854a2f..96be9fab0e 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -44,7 +44,7 @@ internal sealed class Options public int NewDim = RandomFourierFeaturizingEstimator.Defaults.NewDim; [Argument(ArgumentType.Multiple, HelpText = "Which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))] - public IComponentFactory MatrixGenerator = new GaussianFourierSampler.Arguments(); + public IComponentFactory MatrixGenerator = new GaussianFourierSampler.Options(); [Argument(ArgumentType.AtMostOnce, HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)")] public bool UseSin = RandomFourierFeaturizingEstimator.Defaults.UseSin; @@ -666,7 +666,7 @@ public ColumnInfo(string name, int newDim, bool useSin, string inputColumnName = Contracts.CheckUserArg(newDim > 0, nameof(newDim), "must be positive."); InputColumnName = inputColumnName ?? name; Name = name; - Generator = generator ?? new GaussianFourierSampler.Arguments(); + Generator = generator ?? new GaussianFourierSampler.Options(); NewDim = newDim; UseSin = useSin; Seed = seed; diff --git a/test/Microsoft.ML.Benchmarks/RffTransform.cs b/test/Microsoft.ML.Benchmarks/RffTransform.cs index b94eb7930e..815f8a3f33 100644 --- a/test/Microsoft.ML.Benchmarks/RffTransform.cs +++ b/test/Microsoft.ML.Benchmarks/RffTransform.cs @@ -30,7 +30,7 @@ public void SetupTrainingSpeedTests() public void CV_Multiclass_Digits_RffTransform_OVAAveragedPerceptron() { var mlContext = new MLContext(); - var reader = mlContext.Data.CreateTextLoader(new TextLoader.Arguments + var reader = mlContext.Data.CreateTextLoader(new TextLoader.Options { Columns = new[] { diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index c08118bee6..7e97160e96 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -77,7 +77,7 @@ private TransformerChain>(scoreColumn.Value.Index); - var c = new MultiAverage(Env, new MultiAverage.Arguments()).GetCombiner(); + var c = new MultiAverage(Env, new MultiAverage.Options()).GetCombiner(); VBuffer score = default(VBuffer); VBuffer[] score0 = new VBuffer[5]; VBuffer scoreSaved = default(VBuffer); @@ -2513,7 +2513,7 @@ public void TestInputBuilderComponentFactories() expected = FixWhitespace(expected); Assert.Equal(expected, json); - options.LossFunction = new HingeLoss.Arguments(); + options.LossFunction = new HingeLoss.Options(); result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap); json = FixWhitespace(result.ToString(Formatting.Indented)); @@ -2529,7 +2529,7 @@ public void TestInputBuilderComponentFactories() expected = FixWhitespace(expected); Assert.Equal(expected, json); - options.LossFunction = new HingeLoss.Arguments() { Margin = 2 }; + options.LossFunction = new HingeLoss.Options() { Margin = 2 }; result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap); json = FixWhitespace(result.ToString(Formatting.Indented)); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs index faa1ba5f26..fff6e594f2 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs @@ -84,7 +84,7 @@ public void LogEventProcessesMessages() // create a dummy text reader to trigger log messages env.Data.CreateTextLoader( - new TextLoader.Arguments {Columns = new[] {new TextLoader.Column("TestColumn", null, 0)}}); + new TextLoader.Options {Columns = new[] {new TextLoader.Column("TestColumn", null, 0)}}); Assert.True(messages.Count > 0); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs index ef03fdce16..8a10eb6051 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestLoss.cs @@ -64,8 +64,8 @@ public void LossHinge() [Fact] public void LossExponential() { - ExpLoss.Arguments args = new ExpLoss.Arguments(); - ExpLoss loss = new ExpLoss(args); + ExpLoss.Options options = new ExpLoss.Options(); + ExpLoss loss = new ExpLoss(options); TestHelper(loss, 1, 3, Math.Exp(-3), Math.Exp(-3)); TestHelper(loss, 0, 3, Math.Exp(3), -Math.Exp(3)); TestHelper(loss, 0, -3, Math.Exp(-3), -Math.Exp(-3)); diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs index a1e675d248..a39bd14884 100644 --- a/test/Microsoft.ML.Functional.Tests/Validation.cs +++ b/test/Microsoft.ML.Functional.Tests/Validation.cs @@ -82,7 +82,7 @@ public void TrainWithValidationSet() var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options { NumTrees = 2, EarlyStoppingMetrics = 2, - EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments() + EarlyStoppingRule = new GLEarlyStoppingCriterion.Options() }) .Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData); diff --git a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs index 69ddb618ac..234ff0a7d5 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs @@ -522,7 +522,7 @@ public void TestGamRegressionIni() { var mlContext = new MLContext(seed: 0); var idv = mlContext.Data.CreateTextLoader( - new TextLoader.Arguments() + new TextLoader.Options() { HasHeader = false, Columns = new[] @@ -561,7 +561,7 @@ public void TestGamBinaryClassificationIni() { var mlContext = new MLContext(seed: 0); var idv = mlContext.Data.CreateTextLoader( - new TextLoader.Arguments() + new TextLoader.Options() { HasHeader = false, Columns = new[] diff --git a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs index 077ed5104c..9f88d678d7 100644 --- a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs +++ b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs @@ -15,7 +15,7 @@ public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep() var env = new MLContext(42); var sweeper = new UniformRandomSweeper(env, - new SweeperBase.ArgumentsBase(), + new SweeperBase.OptionsBase(), new[] { valueGenerator }); var results = sweeper.ProposeSweeps(3); @@ -32,7 +32,7 @@ public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep() var env = new MLContext(42); var sweeper = new RandomGridSweeper(env, - new RandomGridSweeper.Arguments(), + new RandomGridSweeper.Options(), new[] { valueGenerator }); var results = sweeper.ProposeSweeps(3); diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs index 11bd4ae0db..07dd0db895 100644 --- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs +++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs @@ -90,7 +90,7 @@ public void TestDiscreteValueSweep(double normalizedValue, string expected) public void TestRandomSweeper() { var env = new MLContext(42); - var args = new SweeperBase.ArgumentsBase() + var args = new SweeperBase.OptionsBase() { SweptParameters = new[] { ComponentFactoryUtils.CreateFromFunction( @@ -131,7 +131,7 @@ public void TestSimpleSweeperAsync() var random = new Random(42); var env = new MLContext(42); const int sweeps = 100; - var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase + var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.OptionsBase { SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( @@ -154,7 +154,7 @@ public void TestSimpleSweeperAsync() CheckAsyncSweeperResult(paramSets); // Test consumption without ever calling Update. - var gridArgs = new RandomGridSweeper.Arguments(); + var gridArgs = new RandomGridSweeper.Options(); gridArgs.SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), @@ -178,13 +178,13 @@ public void TestDeterministicSweeperAsyncCancellation() { var random = new Random(42); var env = new MLContext(42); - var args = new DeterministicSweeperAsync.Arguments(); + var args = new DeterministicSweeperAsync.Options(); args.BatchSize = 5; args.Relaxation = 1; args.Sweeper = ComponentFactoryUtils.CreateFromFunction( environ => new KdoSweeper(environ, - new KdoSweeper.Arguments() + new KdoSweeper.Options() { SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( @@ -228,13 +228,13 @@ public void TestDeterministicSweeperAsync() { var random = new Random(42); var env = new MLContext(42); - var args = new DeterministicSweeperAsync.Arguments(); + var args = new DeterministicSweeperAsync.Options(); args.BatchSize = 5; args.Relaxation = args.BatchSize - 1; args.Sweeper = ComponentFactoryUtils.CreateFromFunction( environ => new SmacSweeper(environ, - new SmacSweeper.Arguments() + new SmacSweeper.Options() { SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( @@ -300,13 +300,13 @@ public void TestDeterministicSweeperAsyncParallel() const int batchSize = 5; const int sweeps = 20; var paramSets = new List(); - var args = new DeterministicSweeperAsync.Arguments(); + var args = new DeterministicSweeperAsync.Options(); args.BatchSize = batchSize; args.Relaxation = batchSize - 2; args.Sweeper = ComponentFactoryUtils.CreateFromFunction( environ => new SmacSweeper(environ, - new SmacSweeper.Arguments() + new SmacSweeper.Options() { SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( @@ -352,7 +352,7 @@ public async Task TestNelderMeadSweeperAsync() const int batchSize = 5; const int sweeps = 40; var paramSets = new List(); - var args = new DeterministicSweeperAsync.Arguments(); + var args = new DeterministicSweeperAsync.Options(); args.BatchSize = batchSize; args.Relaxation = 0; @@ -366,12 +366,12 @@ public async Task TestNelderMeadSweeperAsync() innerEnviron => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) }; - var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments() + var nelderMeadSweeperArgs = new NelderMeadSweeper.Options() { SweptParameters = param, FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( (firstBatchSweeperEnviron, firstBatchSweeperArgs) => - new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param })) + new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param })) }; return new NelderMeadSweeper(environ, nelderMeadSweeperArgs); @@ -427,7 +427,7 @@ private void CheckAsyncSweeperResult(List paramSets) public void TestRandomGridSweeper() { var env = new MLContext(42); - var args = new RandomGridSweeper.Arguments() + var args = new RandomGridSweeper.Options() { SweptParameters = new[] { ComponentFactoryUtils.CreateFromFunction( @@ -544,13 +544,13 @@ public void TestNelderMeadSweeper() environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) }; - var args = new NelderMeadSweeper.Arguments() + var args = new NelderMeadSweeper.Options() { SweptParameters = param, FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( (environ, firstBatchArgs) => { - return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }); + return new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param }); } ) }; @@ -600,7 +600,7 @@ public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper() environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) }; - var args = new NelderMeadSweeper.Arguments(); + var args = new NelderMeadSweeper.Options(); args.SweptParameters = param; var sweeper = new NelderMeadSweeper(env, args); var sweeps = sweeper.ProposeSweeps(5, new List()); @@ -642,7 +642,7 @@ public void TestSmacSweeper() var random = new Random(42); var env = new MLContext(42); const int maxInitSweeps = 5; - var args = new SmacSweeper.Arguments() + var args = new SmacSweeper.Options() { NumberInitialPopulation = 20, SweptParameters = new IComponentFactory[] { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 8528a51c1b..10caa9c6de 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -819,7 +819,7 @@ public void SavePipeWithKey() pipe => { - var argsText = new TextLoader.Arguments(); + var argsText = new TextLoader.Options(); bool tmp = CmdParser.ParseArguments(Env, " header=+" + " col=Label:TX:0" + diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 5da719472a..b1a4ad34c9 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -403,7 +403,7 @@ protected bool SaveLoadText(IDataView view, IHostEnvironment env, view = new ChooseColumnsByIndexTransform(env, chooseargs, view); } - var args = new TextLoader.Arguments(); + var args = new TextLoader.Options(); if (!CmdParser.ParseArguments(Env, argsLoader, args)) { Fail("Couldn't parse the args '{0}' in '{1}'", argsLoader, pathData); diff --git a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs index 251b3cc611..f9aba9550b 100644 --- a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs +++ b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs @@ -32,7 +32,7 @@ public void RandomizedPcaTrainerBaselineTest() var mlContext = new MLContext(seed: 1, conc: 1); string featureColumn = "NumericFeatures"; - var reader = new TextLoader(Env, new TextLoader.Arguments() + var reader = new TextLoader(Env, new TextLoader.Options() { HasHeader = true, Separator = "\t", diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index b2f1c26155..2bf7129e61 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -29,7 +29,7 @@ public void TestEstimatorChain() var env = new MLContext(); var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -37,7 +37,7 @@ public void TestEstimatorChain() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var invalidData = TextLoader.Create(env, new TextLoader.Arguments() + var invalidData = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -60,7 +60,7 @@ public void TestEstimatorSaveLoad() IHostEnvironment env = new MLContext(); var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -99,7 +99,7 @@ public void TestSaveImages() var env = new MLContext(); var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -138,7 +138,7 @@ public void TestGreyscaleTransformImages() var imageWidth = 100; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -189,7 +189,7 @@ public void TestBackAndForthConversionWithAlphaInterleave() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -256,7 +256,7 @@ public void TestBackAndForthConversionWithoutAlphaInterleave() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -323,7 +323,7 @@ public void TestBackAndForthConversionWithAlphaNoInterleave() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -390,7 +390,7 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -457,7 +457,7 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -523,7 +523,7 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -589,7 +589,7 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -655,7 +655,7 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset() const int imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { @@ -718,7 +718,7 @@ public void ImageResizerTransformResizingModeFill() var env = new MLContext(); var dataFile = GetDataPath("images/fillmode.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + var data = TextLoader.Create(env, new TextLoader.Options() { Columns = new[] { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 817b8c5f3a..3a6c643bef 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -171,7 +171,7 @@ public void TrainAveragedPerceptronWithCache() { var mlContext = new MLContext(0); var dataFile = GetDataPath("breast-cancer.txt"); - var loader = TextLoader.Create(mlContext, new TextLoader.Arguments(), new MultiFileSource(dataFile)); + var loader = TextLoader.Create(mlContext, new TextLoader.Options(), new MultiFileSource(dataFile)); var globalCounter = 0; var xf = LambdaTransform.CreateFilter(mlContext, loader, (i, s) => true, diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 51a37eccd7..e13d1171cd 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -19,7 +19,7 @@ public void OvaLogisticRegression() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); - var reader = new TextLoader(mlContext, new TextLoader.Arguments() + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] { @@ -51,7 +51,7 @@ public void OvaAveragedPerceptron() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); - var reader = new TextLoader(mlContext, new TextLoader.Arguments() + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] { @@ -84,7 +84,7 @@ public void OvaFastTree() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); - var reader = new TextLoader(mlContext, new TextLoader.Arguments() + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] { @@ -117,7 +117,7 @@ public void OvaLinearSvm() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); - var reader = new TextLoader(mlContext, new TextLoader.Arguments() + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] { diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index 8d615d9330..28e6b56af7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -24,7 +24,7 @@ public void TensorFlowTransforCifarEndToEndTest() var imageFolder = Path.GetDirectoryName(dataFile); var mlContext = new MLContext(seed: 1, conc: 1); - var data = TextLoader.Create(mlContext, new TextLoader.Arguments() + var data = TextLoader.Create(mlContext, new TextLoader.Options() { Columns = new[] { diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs index fad2b231d2..4cc8f02f92 100644 --- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -55,7 +55,7 @@ void TestDifferentTypes() { string dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var loader = new TextLoader(ML, new TextLoader.Arguments + var loader = new TextLoader(ML, new TextLoader.Options { Columns = new[]{ new TextLoader.Column("float1", DataKind.R4, 9), diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 93f2bab7f0..359e44b7e1 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -62,9 +62,9 @@ public void FieldAwareFactorizationMachine_Estimator() Done(); } - private TextLoader.Arguments GetFafmBCLoaderArgs() + private TextLoader.Options GetFafmBCLoaderArgs() { - return new TextLoader.Arguments() + return new TextLoader.Options() { Separator = "\t", HasHeader = false, diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index e757027e7b..6d2243de8f 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -143,9 +143,9 @@ public void MatrixFactorizationSimpleTrainAndPredict() var modelWithValidation = pipeline.Fit(data, testData); } - private TextLoader.Arguments GetLoaderArgs(string labelColumnName, string matrixColumnIndexColumnName, string matrixRowIndexColumnName) + private TextLoader.Options GetLoaderArgs(string labelColumnName, string matrixColumnIndexColumnName, string matrixRowIndexColumnName) { - return new TextLoader.Arguments() + return new TextLoader.Options() { Separator = "\t", HasHeader = true, diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs index b862809cc8..84a71dadfb 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/PriorRandomTests.cs @@ -15,7 +15,7 @@ public partial class TrainerEstimators private IDataView GetBreastCancerDataviewWithTextColumns() { return new TextLoader(Env, - new TextLoader.Arguments() + new TextLoader.Options() { HasHeader = true, Columns = new[] diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 9050f2b454..31356057a6 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -28,7 +28,7 @@ public void PCATrainerEstimator() { string featureColumn = "NumericFeatures"; - var reader = new TextLoader(Env, new TextLoader.Arguments() + var reader = new TextLoader(Env, new TextLoader.Options() { HasHeader = true, Separator = "\t", @@ -56,7 +56,7 @@ public void KMeansEstimator() string featureColumn = "NumericFeatures"; string weights = "Weights"; - var reader = new TextLoader(Env, new TextLoader.Arguments + var reader = new TextLoader(Env, new TextLoader.Options { HasHeader = true, Separator = "\t", @@ -159,7 +159,7 @@ public void TestEstimatorMultiClassNaiveBayesTrainer() private (IEstimator, IDataView) GetBinaryClassificationPipeline() { var data = new TextLoader(Env, - new TextLoader.Arguments() + new TextLoader.Options() { Separator = "\t", HasHeader = true, @@ -179,7 +179,7 @@ public void TestEstimatorMultiClassNaiveBayesTrainer() private (IEstimator, IDataView) GetRankingPipeline() { - var data = new TextLoader(Env, new TextLoader.Arguments + var data = new TextLoader(Env, new TextLoader.Options { HasHeader = true, Separator = "\t", @@ -202,7 +202,7 @@ public void TestEstimatorMultiClassNaiveBayesTrainer() private IDataView GetRegressionPipeline() { return new TextLoader(Env, - new TextLoader.Arguments() + new TextLoader.Options() { Separator = ";", HasHeader = true, @@ -214,9 +214,9 @@ private IDataView GetRegressionPipeline() }).Read(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename)); } - private TextLoader.Arguments GetIrisLoaderArgs() + private TextLoader.Options GetIrisLoaderArgs() { - return new TextLoader.Arguments() + return new TextLoader.Options() { Separator = "comma", HasHeader = true, @@ -230,7 +230,7 @@ private TextLoader.Arguments GetIrisLoaderArgs() private (IEstimator, IDataView) GetMultiClassPipeline() { - var data = new TextLoader(Env, new TextLoader.Arguments() + var data = new TextLoader(Env, new TextLoader.Options() { Separator = "comma", Columns = new[] diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs index a8ad138588..1a3614d25b 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs @@ -26,7 +26,7 @@ void TestConcat() string dataPath = GetDataPath("adult.tiny.with-schema.txt"); var source = new MultiFileSource(dataPath); - var loader = new TextLoader(ML, new TextLoader.Arguments + var loader = new TextLoader(ML, new TextLoader.Options { Columns = new[]{ new TextLoader.Column("float1", DataKind.R4, 9), @@ -84,7 +84,7 @@ public void ConcatWithAliases() string dataPath = GetDataPath("adult.tiny.with-schema.txt"); var source = new MultiFileSource(dataPath); - var loader = new TextLoader(ML, new TextLoader.Arguments + var loader = new TextLoader(ML, new TextLoader.Options { Columns = new[]{ new TextLoader.Column("float1", DataKind.R4, 9), diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs index 6002af15f6..55ccb06de1 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs @@ -26,7 +26,7 @@ public void KeyToValueWorkout() { string dataPath = GetDataPath("iris.txt"); - var reader = new TextLoader(Env, new TextLoader.Arguments + var reader = new TextLoader(Env, new TextLoader.Options { Columns = new[] { diff --git a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs index 873bc5333e..8d5fa5bd88 100644 --- a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs @@ -32,7 +32,7 @@ public void NormalizerWorkout() { string dataPath = GetDataPath(TestDatasets.iris.trainFilename); - var loader = new TextLoader(Env, new TextLoader.Arguments + var loader = new TextLoader(Env, new TextLoader.Options { Columns = new[] { new TextLoader.Column("float1", DataKind.R4, 1), @@ -97,7 +97,7 @@ public void NormalizerParameters() { string dataPath = GetDataPath("iris.txt"); - var loader = new TextLoader(Env, new TextLoader.Arguments + var loader = new TextLoader(Env, new TextLoader.Options { Columns = new[] { new TextLoader.Column("float1", DataKind.R4, 1), @@ -214,7 +214,7 @@ public void SimpleConstructorsAndExtensions() { string dataPath = GetDataPath(TestDatasets.iris.trainFilename); - var loader = new TextLoader(Env, new TextLoader.Arguments + var loader = new TextLoader(Env, new TextLoader.Options { Columns = new[] { new TextLoader.Column("Label", DataKind.R4, 0), @@ -479,7 +479,7 @@ public void TestGcnNormOldSavingAndLoading() void TestNormalizeBackCompatibility() { var dataFile = GetDataPath("breast-cancer.txt"); - var dataView = TextLoader.Create(ML, new TextLoader.Arguments(), new MultiFileSource(dataFile)); + var dataView = TextLoader.Create(ML, new TextLoader.Options(), new MultiFileSource(dataFile)); string chooseModelPath = GetDataPath("backcompat/ap_with_norm.zip"); using (FileStream fs = File.OpenRead(chooseModelPath)) { diff --git a/test/Microsoft.ML.Tests/Transformers/RffTests.cs b/test/Microsoft.ML.Tests/Transformers/RffTests.cs index d0912e846e..c58e56b4d7 100644 --- a/test/Microsoft.ML.Tests/Transformers/RffTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/RffTests.cs @@ -51,11 +51,11 @@ public void RffWorkout() var invalidData = ML.Data.ReadFromEnumerable(new[] { new TestClassInvalidSchema { A = 1 }, new TestClassInvalidSchema { A = 1 } }); var validFitInvalidData = ML.Data.ReadFromEnumerable(new[] { new TestClassBiggerSize { A = new float[200] }, new TestClassBiggerSize { A = new float[200] } }); var dataView = ML.Data.ReadFromEnumerable(data); - var generator = new GaussianFourierSampler.Arguments(); + var generator = new GaussianFourierSampler.Options(); var pipe = ML.Transforms.Projection.CreateRandomFourierFeatures(new[]{ new RandomFourierFeaturizingEstimator.ColumnInfo("RffA", 5, false, "A"), - new RandomFourierFeaturizingEstimator.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Arguments()) + new RandomFourierFeaturizingEstimator.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Options()) }); TestEstimatorCore(pipe, dataView, invalidInput: invalidData, validForFitNotValidForTransformInput: validFitInvalidData); @@ -105,7 +105,7 @@ public void TestOldSavingAndLoading() var est = ML.Transforms.Projection.CreateRandomFourierFeatures(new[]{ new RandomFourierFeaturizingEstimator.ColumnInfo("RffA", 5, false, "A"), - new RandomFourierFeaturizingEstimator.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Arguments()) + new RandomFourierFeaturizingEstimator.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Options()) }); var result = est.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index e5372d964a..aadf328c8b 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -164,7 +164,7 @@ public void StopWordsRemoverFromFactory() { var factory = new PredefinedStopWordsRemoverFactory(); string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); - var data = TextLoader.Create(ML, new TextLoader.Arguments() + var data = TextLoader.Create(ML, new TextLoader.Options() { Columns = new[] { diff --git a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs index ab396485a6..8f877eba5d 100644 --- a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs @@ -24,7 +24,7 @@ public void TestWordEmbeddings() { var dataPath = GetDataPath(TestDatasets.Sentiment.trainFilename); var data = new TextLoader(ML, - new TextLoader.Arguments() + new TextLoader.Options() { Separator = "\t", HasHeader = true, @@ -59,7 +59,7 @@ public void TestCustomWordEmbeddings() { var dataPath = GetDataPath(TestDatasets.Sentiment.trainFilename); var data = new TextLoader(ML, - new TextLoader.Arguments() + new TextLoader.Options() { Separator = "\t", HasHeader = true,