Skip to content

Commit 3dd3a1e

Browse files
authored
Hide much infrastructure in data (#2300)
* Internalize normalizer infrastructure, normalizer transform, entry-points.
* Internalization of component level command line infrastructure.
* Parse/TryUnparse methods internalized.
* Lockdown of column abstract class constructors to make them unextensible.
* Removal of legacy API derived interfaces and abstract classes for columns.
* Internalization of column bindings utilities and infrastructure.
* Hide filtering/misc IDataTransform implementations and associated entry-points.
* Internalize or move much of entry-points.
* Internalize common outputs of entry-points and deal with resulting fallout.
* Internalization of common inputs.
* Internalization of more model saving infrastructure.
1 parent a869547 commit 3dd3a1e

File tree

130 files changed

+375
-314
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

130 files changed

+375
-314
lines changed

src/Microsoft.ML.Core/CommandLine/CmdParser.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1359,17 +1359,17 @@ public ArgumentInfo(Type type, Argument argDef, Argument[] args, Dictionary<stri
13591359
private static MethodInfo GetParseMethod(Type type)
13601360
{
13611361
Contracts.AssertValue(type);
1362-
var meth = type.GetMethod("Parse", new[] { typeof(string) });
1363-
if (meth != null && meth.IsStatic && meth.IsPublic && meth.ReturnType == type)
1362+
var meth = type.GetMethod("Parse", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public, binder: null, new[] { typeof(string) }, null);
1363+
if (meth != null && meth.IsStatic && !meth.IsPrivate && meth.ReturnType == type)
13641364
return meth;
13651365
return null;
13661366
}
13671367

13681368
private static MethodInfo GetUnparseMethod(Type type)
13691369
{
13701370
Contracts.AssertValue(type);
1371-
var meth = type.GetMethod("TryUnparse", new[] { typeof(StringBuilder) });
1372-
if (meth != null && !meth.IsStatic && meth.IsPublic && meth.ReturnType == typeof(bool))
1371+
var meth = type.GetMethod("TryUnparse", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, binder: null, new[] { typeof(StringBuilder) }, null);
1372+
if (meth != null && !meth.IsPrivate && meth.ReturnType == typeof(bool))
13731373
return meth;
13741374
return null;
13751375
}

src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ private void ScanForEntryPoints(LoadableClassInfo info)
468468
var type = info.LoaderType;
469469

470470
// Scan for entry points.
471-
foreach (var methodInfo in type.GetMethods(BindingFlags.Static | BindingFlags.Public))
471+
foreach (var methodInfo in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic))
472472
{
473473
var attr = methodInfo.GetCustomAttributes(typeof(TlcModule.EntryPointAttribute), false).FirstOrDefault() as TlcModule.EntryPointAttribute;
474474
if (attr == null)
@@ -727,7 +727,7 @@ internal LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureT
727727
[BestFriend]
728728
internal IEnumerable<EntryPointInfo> AllEntryPoints()
729729
{
730-
return _entryPoints.AsEnumerable();
730+
return _entryPoints;
731731
}
732732

733733
[BestFriend]

src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public sealed class Column
9191
[Argument(ArgumentType.Required, HelpText = "Index of the directory representing this column.")]
9292
public int Source;
9393

94-
public static Column Parse(string str)
94+
internal static Column Parse(string str)
9595
{
9696
Contracts.AssertNonEmpty(str);
9797

@@ -103,7 +103,7 @@ public static Column Parse(string str)
103103
return null;
104104
}
105105

106-
public static bool TryParse(string str, out Column column)
106+
private static bool TryParse(string str, out Column column)
107107
{
108108
column = null;
109109

@@ -138,7 +138,7 @@ public static bool TryParse(string str, out Column column)
138138
return true;
139139
}
140140

141-
public bool TryUnparse(StringBuilder sb)
141+
internal bool TryUnparse(StringBuilder sb)
142142
{
143143
Contracts.AssertValue(sb);
144144

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public Column(string name, DataKind? type, Range[] source, KeyCount keyCount = n
7474
[Argument(ArgumentType.Multiple, HelpText = "For a key column, this defines the range of values", ShortName = "key")]
7575
public KeyCount KeyCount;
7676

77-
public static Column Parse(string str)
77+
internal static Column Parse(string str)
7878
{
7979
Contracts.AssertNonEmpty(str);
8080

@@ -125,7 +125,7 @@ public static bool TryParseSourceEx(string str, out Range[] ranges)
125125
return true;
126126
}
127127

128-
public bool TryUnparse(StringBuilder sb)
128+
internal bool TryUnparse(StringBuilder sb)
129129
{
130130
Contracts.AssertValue(sb);
131131

@@ -172,7 +172,7 @@ public bool TryUnparse(StringBuilder sb)
172172
/// <summary>
173173
/// Returns <c>true</c> iff the ranges are disjoint, and each range satisfies 0 &lt;= min &lt;= max.
174174
/// </summary>
175-
public bool IsValid()
175+
internal bool IsValid()
176176
{
177177
if (Utils.Size(Source) == 0)
178178
return false;
@@ -258,7 +258,7 @@ public Range(int min, int? max)
258258
[Argument(ArgumentType.AtMostOnce, HelpText = "Force scalar columns to be treated as vectors of length one", ShortName = "vector")]
259259
public bool ForceVector;
260260

261-
public static Range Parse(string str)
261+
internal static Range Parse(string str)
262262
{
263263
Contracts.AssertNonEmpty(str);
264264

@@ -314,7 +314,7 @@ private bool TryParse(string str)
314314
return true;
315315
}
316316

317-
public bool TryUnparse(StringBuilder sb)
317+
internal bool TryUnparse(StringBuilder sb)
318318
{
319319
Contracts.AssertValue(sb);
320320
char dash = AllOther ? '~' : '-';

src/Microsoft.ML.Data/Dirty/PredictorUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public static void SaveText(IChannel ch, IPredictor predictor, RoleMappedSchema
6464
}
6565

6666
/// <summary>
67-
/// Save the model in binary format (if it can save itself)
67+
/// Save the model in binary format (if it can save itself).
6868
/// </summary>
6969
public static void SaveBinary(IChannel ch, IPredictor predictor, BinaryWriter writer)
7070
{

src/Microsoft.ML.Data/EntryPoints/Cache.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
[assembly: LoadableClass(typeof(void), typeof(Cache), null, typeof(SignatureEntryPointModule), "Cache")]
1414
namespace Microsoft.ML.EntryPoints
1515
{
16-
public static class Cache
16+
internal static class Cache
1717
{
1818
public enum CachingType
1919
{

src/Microsoft.ML.Data/EntryPoints/CommonOutputs.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
namespace Microsoft.ML.EntryPoints
1010
{
1111
/// <summary>
12-
/// Common output classes for trainers and transforms.
12+
/// Common output classes for trainers and transform entry-points.
1313
/// </summary>
14-
public static class CommonOutputs
14+
[BestFriend]
15+
internal static class CommonOutputs
1516
{
1617
/// <summary>
1718
/// The common output class for all transforms.

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+24-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using Microsoft.Data.DataView;
88
using Microsoft.ML.CommandLine;
9+
using Microsoft.ML.Core.Data;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.Data.IO;
1112
using Microsoft.ML.Internal.Calibration;
@@ -18,11 +19,17 @@ namespace Microsoft.ML.EntryPoints
1819
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
1920
public abstract class TransformInputBase
2021
{
22+
/// <summary>
23+
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
24+
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
25+
/// </summary>
26+
[BestFriend]
2127
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
22-
public IDataView Data;
28+
internal IDataView Data;
2329
}
2430

25-
public enum CachingOptions
31+
[BestFriend]
32+
internal enum CachingOptions
2633
{
2734
Auto,
2835
Memory,
@@ -37,10 +44,13 @@ public enum CachingOptions
3744
public abstract class LearnerInputBase
3845
{
3946
/// <summary>
40-
/// The data to be used for training.
47+
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
48+
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
49+
/// method.
4150
/// </summary>
51+
[BestFriend]
4252
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
43-
public IDataView TrainingData;
53+
internal IDataView TrainingData;
4454

4555
/// <summary>
4656
/// Column to use for features.
@@ -49,16 +59,20 @@ public abstract class LearnerInputBase
4959
public string FeatureColumn = DefaultColumnNames.Features;
5060

5161
/// <summary>
52-
/// Normalize option for the feature column.
62+
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
5363
/// </summary>
64+
[BestFriend]
5465
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
55-
public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
66+
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
5667

5768
/// <summary>
58-
/// Whether learner should cache input training data.
69+
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
70+
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or another method
71+
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
5972
/// </summary>
73+
[BestFriend]
6074
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
61-
public CachingOptions Caching = CachingOptions.Auto;
75+
internal CachingOptions Caching = CachingOptions.Auto;
6276
}
6377

6478
/// <summary>
@@ -221,7 +235,8 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
221235
/// <summary>
222236
/// Common input interfaces for TLC components.
223237
/// </summary>
224-
public static class CommonInputs
238+
[BestFriend]
239+
internal static class CommonInputs
225240
{
226241
/// <summary>
227242
/// Interface that all API transform input classes will implement.

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace Microsoft.ML.EntryPoints
1313
{
14-
public static class SchemaManipulation
14+
internal static class SchemaManipulation
1515
{
1616
[TlcModule.EntryPoint(Name = "Transforms.ColumnConcatenator", Desc = ColumnConcatenatingTransformer.Summary, UserName = ColumnConcatenatingTransformer.UserName, ShortName = ColumnConcatenatingTransformer.LoadName)]
1717
public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Arguments input)

src/Microsoft.ML.Data/EntryPoints/SelectRows.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
namespace Microsoft.ML.EntryPoints
1111
{
12-
public static class SelectRows
12+
/// <summary>
13+
/// Entry point methods for row filtering and selection.
14+
/// </summary>
15+
internal static class SelectRows
1316
{
1417
[TlcModule.EntryPoint(Name = "Transforms.RowRangeFilter", Desc = RangeFilter.Summary, UserName = RangeFilter.UserName, ShortName = RangeFilter.LoaderSignature)]
1518
public static CommonOutputs.TransformOutput FilterByRange(IHostEnvironment env, RangeFilter.Arguments input)

src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
namespace Microsoft.ML.EntryPoints
1717
{
18-
public static class SummarizePredictor
18+
[BestFriend]
19+
internal static class SummarizePredictor
1920
{
2021
public abstract class InputBase
2122
{

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
764764
}
765765
}
766766

767-
public static partial class Evaluate
767+
internal static partial class Evaluate
768768
{
769769
[TlcModule.EntryPoint(Name = "Models.AnomalyDetectionEvaluator", Desc = "Evaluates an anomaly detection scored dataset.")]
770770
public static CommonOutputs.CommonEvaluateOutput AnomalyDetection(IHostEnvironment env, AnomalyDetectionMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,7 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
15061506
}
15071507
}
15081508

1509-
public static partial class Evaluate
1509+
internal static partial class Evaluate
15101510
{
15111511
[TlcModule.EntryPoint(Name = "Models.BinaryClassificationEvaluator", Desc = "Evaluates a binary classification scored dataset.")]
15121512
public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment env, BinaryClassifierMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst
851851
}
852852
}
853853

854-
public static partial class Evaluate
854+
internal static partial class Evaluate
855855
{
856856
[TlcModule.EntryPoint(Name = "Models.ClusterEvaluator", Desc = "Evaluates a clustering scored dataset.")]
857857
public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env, ClusteringMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst
10291029
}
10301030
}
10311031

1032-
public static partial class Evaluate
1032+
internal static partial class Evaluate
10331033
{
10341034
[TlcModule.EntryPoint(Name = "Models.ClassificationEvaluator", Desc = "Evaluates a multi class classification scored dataset.")]
10351035
public static CommonOutputs.ClassificationEvaluateOutput MultiClass(IHostEnvironment env, MultiClassMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<str
767767
}
768768
}
769769

770-
public static partial class Evaluate
770+
internal static partial class Evaluate
771771
{
772772
[TlcModule.EntryPoint(Name = "Models.MultiOutputRegressionEvaluator", Desc = "Evaluates a multi output regression scored dataset.")]
773773
public static CommonOutputs.CommonEvaluateOutput MultiOutputRegression(IHostEnvironment env, MultiOutputRegressionMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
552552
}
553553
}
554554

555-
public static partial class Evaluate
555+
internal static partial class Evaluate
556556
{
557557
[TlcModule.EntryPoint(Name = "Models.QuantileRegressionEvaluator", Desc = "Evaluates a quantile regression scored dataset.")]
558558
public static CommonOutputs.CommonEvaluateOutput QuantileRegression(IHostEnvironment env, QuantileRegressionMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ private static Comparison<int> GetCompareItems(List<short> queryLabels, List<Sin
10491049
}
10501050
}
10511051

1052-
public static partial class Evaluate
1052+
internal static partial class Evaluate
10531053
{
10541054
[TlcModule.EntryPoint(Name = "Models.RankerEvaluator", Desc = "Evaluates a ranking scored dataset.")]
10551055
public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, RankerMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
371371
}
372372
}
373373

374-
public static partial class Evaluate
374+
internal static partial class Evaluate
375375
{
376376
[TlcModule.EntryPoint(Name = "Models.RegressionEvaluator", Desc = "Evaluates a regression scored dataset.")]
377377
public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env, RegressionMamlEvaluator.Arguments input)

src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ namespace Microsoft.ML.Model.Onnx
1313
/// That method creates a node with inputs and outputs, but this object can modify the node further
1414
/// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations).
1515
/// </summary>
16-
public abstract class OnnxNode
16+
[BestFriend]
17+
internal abstract class OnnxNode
1718
{
1819
public abstract void AddAttribute(string argName, double value);
1920
public abstract void AddAttribute(string argName, long value);

src/Microsoft.ML.Data/Model/Repository.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ public interface ICanSaveModel
2424
}
2525

2626
/// <summary>
27-
/// For saving to a single stream.
27+
/// For saving to a single stream. Note that this interface is mostly deprecated in favor of
28+
/// saving more comprehensive and composable "model" objects, via <see cref="ICanSaveModel"/>.
2829
/// </summary>
29-
public interface ICanSaveInBinaryFormat
30+
[BestFriend]
31+
internal interface ICanSaveInBinaryFormat
3032
{
3133
void SaveAsBinary(BinaryWriter writer);
3234
}

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx
10701070
return new NaiveCalibrator(env, ctx);
10711071
}
10721072

1073-
public void SaveAsBinary(BinaryWriter writer)
1073+
void ICanSaveInBinaryFormat.SaveAsBinary(BinaryWriter writer)
10741074
{
10751075
ModelSaveContext.Save(writer, SaveCore);
10761076
}
@@ -1717,7 +1717,7 @@ private static PavCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
17171717
return new PavCalibrator(env, ctx);
17181718
}
17191719

1720-
public void SaveAsBinary(BinaryWriter writer)
1720+
void ICanSaveInBinaryFormat.SaveAsBinary(BinaryWriter writer)
17211721
{
17221722
ModelSaveContext.Save(writer, SaveCore);
17231723
}

0 commit comments

Comments
 (0)