From 0a3db35ba378440afa462d3e6099becadc7aca00 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 30 May 2018 14:55:03 -0700 Subject: [PATCH] small code cleanup --- .../EntryPoints/ModuleCatalog.cs | 6 +-- .../EntryPoints/ScoreModel.cs | 5 +- .../AutoInference.cs | 2 +- .../AutoMlUtils.cs | 6 +-- .../PipelinePattern.cs | 4 +- .../EntryPoints/CrossValidationBinaryMacro.cs | 1 - .../EntryPoints/CrossValidationMacro.cs | 5 +- .../Runtime/EntryPoints/ImportTextData.cs | 2 +- .../Runtime/EntryPoints/TrainTestMacro.cs | 6 +-- .../Internal/Tools/CSharpApiGenerator.cs | 50 +++++++++---------- 10 files changed, 41 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs index 586f6a4b02..498a75c9e5 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs @@ -51,7 +51,7 @@ public sealed class EntryPointInfo public readonly Type[] OutputKinds; public readonly ObsoleteAttribute ObsoleteAttribute; - internal EntryPointInfo(IExceptionContext ectx, MethodInfo method, + internal EntryPointInfo(IExceptionContext ectx, MethodInfo method, TlcModule.EntryPointAttribute attribute, ObsoleteAttribute obsoleteAttribute) { Contracts.AssertValueOrNull(ectx); @@ -187,7 +187,7 @@ private ModuleCatalog(IExceptionContext ectx) if (attr == null) continue; - var info = new EntryPointInfo(ectx, methodInfo, attr, + var info = new EntryPointInfo(ectx, methodInfo, attr, methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute); entryPoints.Add(info); @@ -315,7 +315,7 @@ public bool TryFindComponent(Type interfaceType, Type argumentType, out Componen Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); Contracts.CheckValue(argumentType, nameof(argumentType)); - component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType); + component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType); return component != null; } diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index de0a8208b0..96ce0acac9 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -72,11 +72,8 @@ public static Output Score(IHostEnvironment env, Input input) host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - - IPredictor predictor; var inputData = input.Data; - RoleMappedData data; - input.PredictorModel.PrepareData(host, inputData, out data, out predictor); + input.PredictorModel.PrepareData(host, inputData, out RoleMappedData data, out IPredictor predictor); IDataView scoredPipe; using (var ch = host.Start("Creating scoring pipeline")) diff --git a/src/Microsoft.ML.PipelineInference/AutoInference.cs b/src/Microsoft.ML.PipelineInference/AutoInference.cs index 642ff4d0d7..a8681da559 100644 --- a/src/Microsoft.ML.PipelineInference/AutoInference.cs +++ b/src/Microsoft.ML.PipelineInference/AutoInference.cs @@ -353,7 +353,7 @@ private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, S testMetricVal += 1e-10; // Save performance score - candidate.PerformanceSummary = + candidate.PerformanceSummary = new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal); _sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate); _history.Add(candidate); diff --git a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs index a0aae16a63..6aec714618 100644 --- a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs +++ b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.PipelineInference { public static class AutoMlUtils { - public static double ExtractValueFromIDV(IHostEnvironment env, IDataView result, string columnName) + public static double ExtractValueFromIdv(IHostEnvironment env, IDataView result, string columnName) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(result, nameof(result)); @@ -40,8 +40,8 @@ public static double ExtractValueFromIDV(IHostEnvironment env, IDataView result, public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null) { - double testingMetricValue = ExtractValueFromIDV(env, result, metricColumnName); - double trainingMetricValue = trainResult != null ? ExtractValueFromIDV(env, trainResult, metricColumnName) : double.MinValue; + double testingMetricValue = ExtractValueFromIdv(env, result, metricColumnName); + double trainingMetricValue = trainResult != null ? ExtractValueFromIdv(env, trainResult, metricColumnName) : double.MinValue; return new AutoInference.RunSummary(testingMetricValue, 0, 0, trainingMetricValue); } diff --git a/src/Microsoft.ML.PipelineInference/PipelinePattern.cs b/src/Microsoft.ML.PipelineInference/PipelinePattern.cs index c6b4de44fa..6fdf922a68 100644 --- a/src/Microsoft.ML.PipelineInference/PipelinePattern.cs +++ b/src/Microsoft.ML.PipelineInference/PipelinePattern.cs @@ -213,8 +213,8 @@ public void RunTrainTestExperiment(IDataView trainData, IDataView testData, var dataOut = experiment.GetOutput(trainTestOutput.OverallMetrics); var dataOutTraining = experiment.GetOutput(trainTestOutput.TrainingOverallMetrics); - testMetricValue = AutoMlUtils.ExtractValueFromIDV(_env, dataOut, metric.Name); - trainMetricValue = AutoMlUtils.ExtractValueFromIDV(_env, dataOutTraining, metric.Name); + testMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, dataOut, metric.Name); + trainMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, dataOutTraining, metric.Name); } public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data, diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs index fca8d3ac5b..cdf52cf076 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs @@ -256,7 +256,6 @@ public static ArrayITransformModelOutput MakeArray(IHostEnvironment env, ArrayIT return result; } - public sealed class ArrayIDataViewInput { [Argument(ArgumentType.Required, HelpText = "The data sets", SortOrder = 1)] diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 569d0b3571..325b8a093a 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -31,7 +31,7 @@ public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] public Var PredictorModel; - + [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)] public Var TransformModel; } @@ -104,7 +104,6 @@ public sealed class Output public IDataView ConfusionMatrix; } - public sealed class CombineMetricsInput { [Argument(ArgumentType.Multiple, HelpText = "Overall metrics datasets", SortOrder = 1)] @@ -219,7 +218,7 @@ public static CommonOutputs.MacroOutput CrossValidate( } else args.Outputs.TransformModel = null; - + // Set train/test trainer kind to match. args.Kind = input.Kind; diff --git a/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs b/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs index 41048000d8..a3913ab160 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs @@ -64,7 +64,7 @@ public static Output TextLoader(IHostEnvironment env, LoaderInput input) var host = env.Register("ImportTextData"); env.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - var loader = host.CreateLoader(input.Arguments, new FileHandleSource(input.InputFile)); + var loader = host.CreateLoader(input.Arguments, new FileHandleSource(input.InputFile)); return new Output { Data = loader }; } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index edd4cf6e5b..c3c06b1031 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -26,7 +26,7 @@ public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] public Var PredictorModel; - + [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] public Var TransformModel; } @@ -130,7 +130,7 @@ public static CommonOutputs.MacroOutput TrainTest( if (!subGraphRunContext.TryGetVariable(varName, out dataVariable)) throw env.Except($"Invalid variable name '{varName}'."); - string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) : + string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) : node.GetOutputVariableName(nameof(Output.PredictorModel)); foreach (var subGraphNode in subGraphNodes) @@ -249,7 +249,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNodeTraining = evalInputOutputTraining.Item1; var evalOutputTraining = evalInputOutputTraining.Item2; - evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName : + evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName : scoreNodeTrainingOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName)) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 7f5114b185..c893f469ae 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -408,7 +408,7 @@ public static string GetJsonFromField(string fieldName, Type fieldType) private readonly string _regenerate; private readonly HashSet _excludedSet; private const string RegistrationName = "CSharpApiGenerator"; - public Dictionary _typesSymbolTable = new Dictionary(); + public Dictionary TypesSymbolTable = new Dictionary(); public CSharpApiGenerator(IHostEnvironment env, Arguments args, string regenerate) { @@ -612,7 +612,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) type = type.GetGenericArguments()[0]; - if (_typesSymbolTable.ContainsKey(type.FullName)) + if (TypesSymbolTable.ContainsKey(type.FullName)) continue; if (!type.IsEnum) @@ -625,13 +625,13 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu var enumType = Enum.GetUnderlyingType(type); - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); + TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace); if (enumType == typeof(int)) - writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}"); + writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}"); else { Contracts.Assert(enumType == typeof(byte)); - writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)} : byte"); + writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)} : byte"); } writer.Write("{"); @@ -707,19 +707,19 @@ private void GenerateStructs(IndentingTextWriter writer, if (typeEnum != TlcModule.DataKind.Unknown) continue; - if (_typesSymbolTable.ContainsKey(type.FullName)) + if (TypesSymbolTable.ContainsKey(type.FullName)) continue; - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); + TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace); string classBase = ""; if (type.IsSubclassOf(typeof(OneToOneColumn))) - classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; + classBase = $" : OneToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; else if (type.IsSubclassOf(typeof(ManyToOneColumn))) - classBase = $" : ManyToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn"; - writer.WriteLine($"public sealed partial class {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}"); + classBase = $" : ManyToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn"; + writer.WriteLine($"public sealed partial class {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}"); writer.WriteLine("{"); writer.Indent(); - GenerateInputFields(writer, type, catalog, _typesSymbolTable); + GenerateInputFields(writer, type, catalog, TypesSymbolTable); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(); @@ -858,12 +858,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer, writer.Indent(); if (isArray) { - writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(source));"); + writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); + writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source));"); writer.WriteLine($"{fieldName} = list.ToArray();"); } else - writer.WriteLine($"{fieldName} = OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(source);"); + writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source);"); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(); @@ -872,12 +872,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer, writer.Indent(); if (isArray) { - writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source));"); + writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); + writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));"); writer.WriteLine($"{fieldName} = list.ToArray();"); } else - writer.WriteLine($"{fieldName} = OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source);"); + writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);"); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(); @@ -905,12 +905,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer, writer.Indent(); if (isArray) { - writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(ManyToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source));"); + writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); + writer.WriteLine($"list.Add(ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));"); writer.WriteLine($"{fieldName} = list.ToArray();"); } else - writer.WriteLine($"{fieldName} = ManyToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source);"); + writer.WriteLine($"{fieldName} = ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);"); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(); @@ -942,10 +942,10 @@ private void GenerateInput(IndentingTextWriter writer, foreach (var line in entryPointInfo.Description.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries)) writer.WriteLine($"/// {line}"); writer.WriteLine("/// "); - - if(entryPointInfo.ObsoleteAttribute != null) + + if (entryPointInfo.ObsoleteAttribute != null) writer.WriteLine($"[Obsolete(\"{entryPointInfo.ObsoleteAttribute.Message}\")]"); - + writer.WriteLine($"public sealed partial class {classAndMethod.Item2}{classBase}"); writer.WriteLine("{"); writer.Indent(); @@ -955,7 +955,7 @@ private void GenerateInput(IndentingTextWriter writer, GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, classAndMethod.Item2, out Type transformType); writer.WriteLine(); - GenerateInputFields(writer, entryPointInfo.InputType, catalog, _typesSymbolTable); + GenerateInputFields(writer, entryPointInfo.InputType, catalog, TypesSymbolTable); writer.WriteLine(); GenerateOutput(writer, entryPointInfo, out HashSet outputVariableNames); @@ -1191,7 +1191,7 @@ private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.Compone writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}"); writer.WriteLine("{"); writer.Indent(); - GenerateInputFields(writer, component.ArgumentType, catalog, _typesSymbolTable, "Microsoft.ML."); + GenerateInputFields(writer, component.ArgumentType, catalog, TypesSymbolTable, "Microsoft.ML."); writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";"); writer.Outdent(); writer.WriteLine("}");