Skip to content

Commit 6e0d8d0

Browse files
TomFinleyZruty0
authored andcommitted
Static typed Estimator/Transformer/Data (#778)
Fixes #632. Statically typed parallels to `IEstimator`, `ITransformer`, `IDataView`.
1 parent 622e028 commit 6e0d8d0

File tree

65 files changed

+3935
-113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+3935
-113
lines changed

Microsoft.ML.sln

+23-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.CpuMath", "Mic
9393
EndProject
9494
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools-local", "tools-local", "{7F13E156-3EBA-4021-84A5-CD56BA72F99E}"
9595
EndProject
96-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer", "tools-local\Microsoft.ML.CodeAnalyzer\Microsoft.ML.CodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}"
96+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InternalCodeAnalyzer", "tools-local\Microsoft.ML.InternalCodeAnalyzer\Microsoft.ML.InternalCodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}"
9797
EndProject
9898
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
9999
EndProject
@@ -111,6 +111,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners",
111111
EndProject
112112
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}"
113113
EndProject
114+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "src\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj", "{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}"
115+
EndProject
116+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}"
117+
EndProject
114118
Global
115119
GlobalSection(SolutionConfigurationPlatforms) = preSolution
116120
Debug|Any CPU = Debug|Any CPU
@@ -399,6 +403,22 @@ Global
399403
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.Build.0 = Release|Any CPU
400404
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
401405
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
406+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
407+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
408+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
409+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
410+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
411+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release|Any CPU.Build.0 = Release|Any CPU
412+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
413+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
414+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
415+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug|Any CPU.Build.0 = Debug|Any CPU
416+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
417+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
418+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.ActiveCfg = Release|Any CPU
419+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU
420+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
421+
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
402422
EndGlobalSection
403423
GlobalSection(SolutionProperties) = preSolution
404424
HideSolutionNode = FALSE
@@ -444,6 +464,8 @@ Global
444464
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
445465
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
446466
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
467+
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
468+
{8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
447469
EndGlobalSection
448470
GlobalSection(ExtensibilityGlobals) = postSolution
449471
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

build/Dependencies.props

+4
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,9 @@
1212
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
1313
<BenchmarkDotNetVersion>0.11.1</BenchmarkDotNetVersion>
1414
<TensorFlowVersion>1.10.0</TensorFlowVersion>
15+
16+
<MicrosoftCodeAnalysisCSharpVersion>2.9.0</MicrosoftCodeAnalysisCSharpVersion>
17+
<MicrosoftCSharpVersion>4.5.0</MicrosoftCSharpVersion>
18+
<SystemCompositionVersion>1.2.0</SystemCompositionVersion>
1519
</PropertyGroup>
1620
</Project>

src/Directory.Build.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
<ItemGroup>
2727
<ProjectReference
2828
Condition="'$(UseMLCodeAnalyzer)' != 'false' and '$(MSBuildProjectExtension)' == '.csproj'"
29-
Include="$(MSBuildThisFileDirectory)\..\tools-local\Microsoft.ML.CodeAnalyzer\Microsoft.ML.CodeAnalyzer.csproj">
29+
Include="$(MSBuildThisFileDirectory)\..\tools-local\Microsoft.ML.InternalCodeAnalyzer\Microsoft.ML.InternalCodeAnalyzer.csproj">
3030
<ReferenceOutputAssembly>false</ReferenceOutputAssembly>
3131
<OutputItemType>Analyzer</OutputItemType>
3232
</ProjectReference>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard1.3</TargetFramework>
5+
</PropertyGroup>
6+
7+
<ItemGroup>
8+
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
9+
<PackageReference Include="Microsoft.CSharp" Version="$(MicrosoftCSharpVersion)" />
10+
<PackageReference Include="System.Composition" Version="$(SystemCompositionVersion)" />
11+
</ItemGroup>
12+
13+
</Project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Immutable;
6+
using System.Linq;
7+
using Microsoft.CodeAnalysis;
8+
using Microsoft.CodeAnalysis.CSharp;
9+
using Microsoft.CodeAnalysis.CSharp.Syntax;
10+
using Microsoft.CodeAnalysis.Diagnostics;
11+
12+
namespace Microsoft.ML.Analyzer
13+
{
14+
[DiagnosticAnalyzer(LanguageNames.CSharp)]
15+
public sealed class TypeIsSchemaShapeAnalyzer : DiagnosticAnalyzer
16+
{
17+
internal static class ShapeDiagnostic
18+
{
19+
private const string Category = "Type Check";
20+
public const string Id = "MSML_TypeShouldBeSchemaShape";
21+
private const string Title = "The type is not a schema shape";
22+
private const string Format = "Type{0} is neither a PipelineColumn nor a ValueTuple.";
23+
internal const string Description =
24+
"Within statically typed pipeline elements of ML.NET, the shape of the schema is determined by a type. " +
25+
"A valid type is either an instance of one of the PipelineColumn subclasses (e.g., Scalar<bool> " +
26+
"or something like that), or a ValueTuple containing only valid types. (So, ValueTuples containing " +
27+
"other value tuples are fine, so long as they terminate in a PipelineColumn subclass.)";
28+
29+
internal static DiagnosticDescriptor Rule =
30+
new DiagnosticDescriptor(Id, Title, Format, Category,
31+
DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
32+
}
33+
34+
internal static class ShapeParameterDiagnostic
35+
{
36+
private const string Category = "Type Check";
37+
public const string Id = "MSML_TypeParameterShouldBeSchemaShape";
38+
private const string Title = "The type is not a schema shape";
39+
private const string Format = "Type parameter {0} is not marked with [IsShape] or appropriate type constraints.";
40+
internal const string Description = ShapeDiagnostic.Description + " " +
41+
"If using type parameters when interacting with the statically typed pipelines, the type parameter ought to be " +
42+
"constrained in such a way that it, either by applying the [IsShape] attribute or by having type constraints to " +
43+
"indicate that it is valid, e.g., constraining the type to descend from PipelineColumn.";
44+
45+
internal static DiagnosticDescriptor Rule =
46+
new DiagnosticDescriptor(Id, Title, Format, Category,
47+
DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
48+
}
49+
50+
private const string AttributeName = "Microsoft.ML.Data.StaticPipe.IsShapeAttribute";
51+
private const string LeafTypeName = "Microsoft.ML.Data.StaticPipe.Runtime.PipelineColumn";
52+
53+
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
54+
ImmutableArray.Create(ShapeDiagnostic.Rule, ShapeParameterDiagnostic.Rule);
55+
56+
public override void Initialize(AnalysisContext context)
57+
{
58+
context.RegisterSemanticModelAction(Analyze);
59+
}
60+
61+
private void Analyze(SemanticModelAnalysisContext context)
62+
{
63+
// We start with the model, then do the the method invocations.
64+
// We could have phrased it as RegisterSyntaxNodeAction(Analyze, SyntaxKind.InvocationExpression),
65+
// but this seemed more inefficient since getting the model and fetching the type symbols every
66+
// single time seems to incur significant cost. The following invocation is somewhat more awkward
67+
// since we must iterate over the invocation syntaxes ourselves, but this seems to be worthwhile.
68+
var model = context.SemanticModel;
69+
var comp = model.Compilation;
70+
71+
// Get the symbols of the key types we are analyzing. If we can't find any of them there is
72+
// no point in going further.
73+
var attrType = comp.GetTypeByMetadataName(AttributeName);
74+
if (attrType == null)
75+
return;
76+
var leafType = comp.GetTypeByMetadataName(LeafTypeName);
77+
if (leafType == null)
78+
return;
79+
80+
// This internal helper method recursively determines whether an attributed type parameter
81+
// has a valid type. It is called externally from the loop over invocations.
82+
bool CheckType(ITypeSymbol type, out string path, out ITypeSymbol problematicType)
83+
{
84+
if (type.TypeKind == TypeKind.TypeParameter)
85+
{
86+
var typeParam = (ITypeParameterSymbol)type;
87+
path = null;
88+
problematicType = null;
89+
// Does the type parameter have the attribute that triggers a check?
90+
if (type.GetAttributes().Any(attr => attr.AttributeClass == attrType))
91+
return true;
92+
// Are any of the declared constraint types OK?
93+
if (typeParam.ConstraintTypes.Any(ct => CheckType(ct, out string ctPath, out var ctProb)))
94+
return true;
95+
// Well, probably not good then. Let's call it a day.
96+
problematicType = typeParam;
97+
return false;
98+
}
99+
else if (type.IsTupleType)
100+
{
101+
INamedTypeSymbol nameType = (INamedTypeSymbol)type;
102+
var tupleElems = nameType.TupleElements;
103+
104+
for (int i = 0; i < tupleElems.Length; ++i)
105+
{
106+
var e = tupleElems[i];
107+
if (!CheckType(e.Type, out string innerPath, out problematicType))
108+
{
109+
path = e.Name ?? $"Item{i + 1}";
110+
if (innerPath != null)
111+
path += "." + innerPath;
112+
return false;
113+
}
114+
}
115+
path = null;
116+
problematicType = null;
117+
return true;
118+
}
119+
else
120+
{
121+
for (var rt = type; rt != null; rt = rt.BaseType)
122+
{
123+
if (rt == leafType)
124+
{
125+
path = null;
126+
problematicType = null;
127+
return true;
128+
}
129+
}
130+
path = null;
131+
problematicType = type;
132+
return false;
133+
}
134+
}
135+
136+
foreach (var invocation in model.SyntaxTree.GetRoot().DescendantNodes().OfType<InvocationExpressionSyntax>())
137+
{
138+
var symbolInfo = model.GetSymbolInfo(invocation);
139+
if (!(symbolInfo.Symbol is IMethodSymbol methodSymbol))
140+
{
141+
// Should we perhaps skip when there is a method resolution failure? This is often but not always a sign of another problem.
142+
if (symbolInfo.CandidateReason != CandidateReason.OverloadResolutionFailure || symbolInfo.CandidateSymbols.Length == 0)
143+
continue;
144+
methodSymbol = symbolInfo.CandidateSymbols[0] as IMethodSymbol;
145+
if (methodSymbol == null)
146+
continue;
147+
}
148+
// Analysis only applies to generic methods.
149+
if (!methodSymbol.IsGenericMethod)
150+
continue;
151+
// Scan the type parameters for one that has our target attribute.
152+
for (int i = 0; i < methodSymbol.TypeParameters.Length; ++i)
153+
{
154+
var par = methodSymbol.TypeParameters[i];
155+
var attr = par.GetAttributes();
156+
if (attr.Length == 0)
157+
continue;
158+
if (!attr.Any(a => a.AttributeClass == attrType))
159+
continue;
160+
// We've found it. Check the type argument to ensure it is of the appropriate type.
161+
var p = methodSymbol.TypeArguments[i];
162+
if (CheckType(p, out string path, out ITypeSymbol problematicType))
163+
continue;
164+
165+
if (problematicType.Kind == SymbolKind.TypeParameter)
166+
{
167+
var diagnostic = Diagnostic.Create(ShapeParameterDiagnostic.Rule, invocation.GetLocation(), problematicType.Name);
168+
context.ReportDiagnostic(diagnostic);
169+
}
170+
else
171+
{
172+
path = path == null ? "" : " of item " + path;
173+
var diagnostic = Diagnostic.Create(ShapeDiagnostic.Rule, invocation.GetLocation(), path);
174+
context.ReportDiagnostic(diagnostic);
175+
}
176+
}
177+
}
178+
}
179+
}
180+
}

src/Microsoft.ML.Core/Data/DataKind.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public enum DataKind : byte
3030
Num = R4,
3131

3232
TX = 11,
33-
#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independnet of C# naming conventions.
33+
#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions.
3434
TXT = TX,
3535
Text = TX,
3636

src/Microsoft.ML.Core/Data/IEstimator.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,10 @@ public interface IDataReader<in TSource>
192192
public interface IDataReaderEstimator<in TSource, out TReader>
193193
where TReader : IDataReader<TSource>
194194
{
195+
// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
196+
// yet, so why complicate matters?
195197
/// <summary>
196198
/// Train and return a data reader.
197-
///
198-
/// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
199-
/// yet, so why complicate matters?
200199
/// </summary>
201200
TReader Fit(TSource input);
202201

src/Microsoft.ML.Core/Utilities/Utils.cs

+18
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,24 @@ public static void MarshalActionInvoke<TArg1>(Action<TArg1> act, Type genArg, TA
10491049
meth.Invoke(act.Target, new object[] { arg1 });
10501050
}
10511051

1052+
/// <summary>
1053+
/// A two-argument version of <see cref="MarshalActionInvoke(Action, Type)"/>.
1054+
/// </summary>
1055+
public static void MarshalActionInvoke<TArg1, TArg2>(Action<TArg1, TArg2> act, Type genArg, TArg1 arg1, TArg2 arg2)
1056+
{
1057+
var meth = MarshalActionInvokeCheckAndCreate(genArg, act);
1058+
meth.Invoke(act.Target, new object[] { arg1, arg2 });
1059+
}
1060+
1061+
/// <summary>
1062+
/// A three-argument version of <see cref="MarshalActionInvoke(Action, Type)"/>.
1063+
/// </summary>
1064+
public static void MarshalActionInvoke<TArg1, TArg2, TArg3>(Action<TArg1, TArg2, TArg3> act, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3)
1065+
{
1066+
var meth = MarshalActionInvokeCheckAndCreate(genArg, act);
1067+
meth.Invoke(act.Target, new object[] { arg1, arg2, arg3 });
1068+
}
1069+
10521070
public static string GetDescription(this Enum value)
10531071
{
10541072
Type type = value.GetType();

0 commit comments

Comments
 (0)