Skip to content

Commit 40e8e58

Browse files
daholsteDmitry-A
authored andcommitted
add estimator extensions / catalog; add conversion from external to internal pipeline; transform clean-up; add back in test proj and fix build; refactor trainer ext name mappings (dotnet#15)
* Make validation data param mandatory; remove GetFirstPipeline sample * remove deprecated todo * add estimator extensions / catalog; add ability to go from external to internal pipeline; a lot of transform clean-up; add back in test proj and get it building; refactor trainer ext name mappings
1 parent de7ac15 commit 40e8e58

18 files changed

+492
-511
lines changed

AutoML.sln

+24-20
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,54 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoML", "src\AutoML\AutoML
77
EndProject
88
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Samples", "src\Samples\Samples.csproj", "{64A7294E-A2C7-4499-8F0B-4BB074047C6B}"
99
EndProject
10-
#Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "InternalClient", "src\InternalClient\InternalClient.csproj", "{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}"
11-
#EndProject
12-
#Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Test", "src\Test\Test.csproj", "{6DA91D40-302C-495C-B1DA-20701CDA49C6}"
13-
#EndProject
10+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Test", "src\Test\Test.csproj", "{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}"
11+
EndProject
1412
Global
1513
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1614
Debug|Any CPU = Debug|Any CPU
1715
Debug-Intrinsics|Any CPU = Debug-Intrinsics|Any CPU
16+
Debug-netfx|Any CPU = Debug-netfx|Any CPU
1817
Release|Any CPU = Release|Any CPU
1918
Release-Intrinsics|Any CPU = Release-Intrinsics|Any CPU
19+
Release-netfx|Any CPU = Release-netfx|Any CPU
2020
EndGlobalSection
2121
GlobalSection(ProjectConfigurationPlatforms) = postSolution
2222
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
2323
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.Build.0 = Debug|Any CPU
2424
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
2525
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
26+
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
27+
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
2628
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.ActiveCfg = Release|Any CPU
2729
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.Build.0 = Release|Any CPU
2830
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
2931
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
32+
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
33+
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
3034
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
3135
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.Build.0 = Debug|Any CPU
3236
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
3337
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
38+
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
39+
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
3440
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.ActiveCfg = Release|Any CPU
3541
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.Build.0 = Release|Any CPU
3642
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
3743
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
38-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
39-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Debug|Any CPU.Build.0 = Debug|Any CPU
40-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
41-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
42-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Release|Any CPU.ActiveCfg = Release|Any CPU
43-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Release|Any CPU.Build.0 = Release|Any CPU
44-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
45-
{8D564A01-DCA9-443A-9995-A5A67BE4C2CD}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
46-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
47-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Debug|Any CPU.Build.0 = Debug|Any CPU
48-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
49-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
50-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Release|Any CPU.ActiveCfg = Release|Any CPU
51-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Release|Any CPU.Build.0 = Release|Any CPU
52-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
53-
{6DA91D40-302C-495C-B1DA-20701CDA49C6}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
44+
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
45+
{64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
46+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
47+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.Build.0 = Debug|Any CPU
48+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
49+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
50+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
51+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
52+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.ActiveCfg = Release|Any CPU
53+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.Build.0 = Release|Any CPU
54+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
55+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
56+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
57+
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
5458
EndGlobalSection
5559
GlobalSection(SolutionProperties) = preSolution
5660
HideSolutionNode = FALSE

src/AutoML/API/Pipeline.cs

+14-2
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,25 @@ public class PipelineNode
2222

2323
public PipelineNode(string name, PipelineNodeType elementType,
2424
string[] inColumns, string[] outColumns,
25-
IDictionary<string, object> properties)
25+
IDictionary<string, object> properties = null)
2626
{
2727
Name = name;
2828
ElementType = elementType;
2929
InColumns = inColumns;
3030
OutColumns = outColumns;
31-
Properties = properties;
31+
Properties = properties ?? new Dictionary<string, object>();
32+
}
33+
34+
public PipelineNode(string name, PipelineNodeType elementType,
35+
string inColumn, string outColumn, IDictionary<string, object> properties = null) :
36+
this(name, elementType, new string[] { inColumn }, new string[] { outColumn }, properties)
37+
{
38+
}
39+
40+
public PipelineNode(string name, PipelineNodeType elementType,
41+
string[] inColumns, string outColumn, IDictionary<string, object> properties = null) :
42+
this(name, elementType, inColumns, new string[] { outColumn }, properties)
43+
{
3244
}
3345
}
3446

src/AutoML/Assembly.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
using System.Runtime.CompilerServices;
22

3-
// [assembly: InternalsVisibleTo("InternalClient")]
4-
// [assembly: InternalsVisibleTo("Test")]
3+
//[assembly: InternalsVisibleTo("InternalClient")]
4+
[assembly: InternalsVisibleTo("Test, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]

src/AutoML/AutoFitter/InferredPipeline.cs

+40-13
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using System.Linq;
78
using Microsoft.ML.Core.Data;
89
using Microsoft.ML.Data;
9-
using static Microsoft.ML.Auto.TransformInference.ColumnRoutingStructure;
1010

1111
namespace Microsoft.ML.Auto
1212
{
@@ -22,12 +22,17 @@ internal class InferredPipeline
2222

2323
public InferredPipeline(IEnumerable<SuggestedTransform> transforms,
2424
SuggestedTrainer trainer,
25-
MLContext context = null)
25+
MLContext context = null,
26+
bool autoNormalize = true)
2627
{
2728
Transforms = transforms.Select(t => t.Clone()).ToList();
2829
Trainer = trainer.Clone();
2930
_context = context ?? new MLContext();
30-
AddNormalizationTransforms();
31+
32+
if(autoNormalize)
33+
{
34+
AddNormalizationTransforms();
35+
}
3136
}
3237

3338
public override string ToString() => $"{Trainer}+{string.Join("+", Transforms.Select(t => t.ToString()))}";
@@ -52,12 +57,42 @@ public Pipeline ToPipeline()
5257
var pipelineElements = new List<PipelineNode>();
5358
foreach(var transform in Transforms)
5459
{
55-
pipelineElements.Add(transform.ToPipelineNode());
60+
pipelineElements.Add(transform.PipelineNode);
5661
}
5762
pipelineElements.Add(Trainer.ToPipelineNode());
5863
return new Pipeline(pipelineElements.ToArray());
5964
}
6065

66+
public static InferredPipeline FromPipeline(Pipeline pipeline)
67+
{
68+
var context = new MLContext();
69+
70+
var transforms = new List<SuggestedTransform>();
71+
SuggestedTrainer trainer = null;
72+
73+
foreach(var pipelineNode in pipeline.Elements)
74+
{
75+
if(pipelineNode.ElementType == PipelineNodeType.Trainer)
76+
{
77+
var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), pipelineNode.Name);
78+
var trainerExtension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
79+
var stringParamVals = pipelineNode.Properties.Select(prop => new StringParameterValue(prop.Key, prop.Value.ToString()));
80+
var hyperParamSet = new ParameterSet(stringParamVals);
81+
trainer = new SuggestedTrainer(context, trainerExtension, hyperParamSet);
82+
}
83+
else if (pipelineNode.ElementType == PipelineNodeType.Transform)
84+
{
85+
var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name);
86+
var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName);
87+
var estimator = estimatorExtension.CreateInstance(new MLContext(), pipelineNode);
88+
var transform = new SuggestedTransform(pipelineNode, estimator);
89+
transforms.Add(transform);
90+
}
91+
}
92+
93+
return new InferredPipeline(transforms, trainer, null, false);
94+
}
95+
6196
public ITransformer TrainTransformer(IDataView trainData)
6297
{
6398
IEstimator<ITransformer> pipeline = new EstimatorChain<ITransformer>();
@@ -91,15 +126,7 @@ private void AddNormalizationTransforms()
91126
return;
92127
}
93128

94-
var estimator = _context.Transforms.Normalize(DefaultColumnNames.Features);
95-
var annotatedColNames = new[] { new AnnotatedName() { Name = DefaultColumnNames.Features, IsNumeric = true } };
96-
var routingStructure = new TransformInference.ColumnRoutingStructure(annotatedColNames, annotatedColNames);
97-
var properties = new Dictionary<string, string>()
98-
{
99-
{ "mode", "MinMax" }
100-
};
101-
var transform = new SuggestedTransform(estimator,
102-
routingStructure: routingStructure, properties: properties);
129+
var transform = NormalizingExtension.CreateSuggestedTransform(_context, DefaultColumnNames.Features, DefaultColumnNames.Features);
103130
Transforms.Add(transform);
104131
}
105132
}

src/AutoML/AutoFitter/SuggestedTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtensio
1919
_mlContext = mlContext;
2020
_trainerExtension = trainerExtension;
2121
SweepParams = _trainerExtension.GetHyperparamSweepRanges();
22-
TrainerName = _trainerExtension.GetTrainerName();
22+
TrainerName = TrainerExtensionCatalog.GetTrainerName(_trainerExtension);
2323
SetHyperparamValues(hyperParamSet);
2424
}
2525

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace Microsoft.ML.Auto
5+
{
6+
public enum EstimatorName
7+
{
8+
ColumnConcatenating,
9+
ColumnCopying,
10+
MissingValueIndicator,
11+
Normalizing,
12+
OneHotEncoding,
13+
OneHotHashEncoding,
14+
TextFeaturizing,
15+
TypeConverting,
16+
ValueToKeyMapping
17+
}
18+
19+
internal class EstimatorExtensionCatalog
20+
{
21+
private static readonly IDictionary<EstimatorName, Type> _namesToExtensionTypes = new
22+
Dictionary<EstimatorName, Type>()
23+
{
24+
{ EstimatorName.ColumnConcatenating, typeof(ColumnConcatenatingExtension) },
25+
{ EstimatorName.ColumnCopying, typeof(ColumnCopyingExtension) },
26+
{ EstimatorName.MissingValueIndicator, typeof(MissingValueIndicatorExtension) },
27+
{ EstimatorName.Normalizing, typeof(NormalizingExtension) },
28+
{ EstimatorName.OneHotEncoding, typeof(OneHotEncodingExtension) },
29+
{ EstimatorName.OneHotHashEncoding, typeof(OneHotHashEncodingExtension) },
30+
{ EstimatorName.TextFeaturizing, typeof(TextFeaturizingExtension) },
31+
{ EstimatorName.TypeConverting, typeof(TypeConvertingExtension) },
32+
{ EstimatorName.ValueToKeyMapping, typeof(ValueToKeyMappingExtension) },
33+
};
34+
35+
public static IEstimatorExtension GetExtension(EstimatorName estimatorName)
36+
{
37+
var extType = _namesToExtensionTypes[estimatorName];
38+
return (IEstimatorExtension)Activator.CreateInstance(extType);
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)