Skip to content

Commit a4bfd93

Browse files
authored
Hide more of Microsoft.ML.Data (#2842)
* Get rid of unnecessary SubCatalogBase base class. * Clean up unintentional usage of protected internal. * ValueToKeyMappingEstimator.ColumnOptions now not publicly subclassable. * Have the public abstract classes for estimator/transformers/etc. be practically unsubclassable. * Internalize the infrastructure surrounding TransformWrapper. * Internalize DefaultColumnNames utility class for storing const string literals. * Use explicit IInternalCatalog, thereby avoiding SubCatalogBase in public surface. * Simplify CatalogUtils. * Copy documentation from IEstimator to two implementations of GetOutputSchema.
1 parent 768e3bf commit a4bfd93

26 files changed

+273
-180
lines changed

src/Microsoft.ML.Data/Commands/DefaultColumnNames.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
namespace Microsoft.ML.Data
66
{
7-
public static class DefaultColumnNames
7+
/// <summary>
8+
/// A set of string literals intended to be "canonical" names for column names intended for particular purpose.
9+
/// While not part of the public API surface, its primary purpose is intended to be used in such a way as to encourage
10+
/// uniformity on the public API surface, wherever it is judged where columns with default names should be consumed.
11+
/// </summary>
12+
[BestFriend]
13+
internal static class DefaultColumnNames
814
{
915
public const string Features = "Features";
1016
public const string Label = "Label";

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

+40-40
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ namespace Microsoft.ML
1414
/// A catalog of operations over data that are not transformers or estimators.
1515
/// This includes data loaders, saving, caching, filtering etc.
1616
/// </summary>
17-
public sealed class DataOperationsCatalog
17+
public sealed class DataOperationsCatalog : IInternalCatalog
1818
{
19-
[BestFriend]
20-
internal IHostEnvironment Environment { get; }
19+
IHostEnvironment IInternalCatalog.Environment => _env;
20+
private readonly IHostEnvironment _env;
2121

2222
internal DataOperationsCatalog(IHostEnvironment env)
2323
{
2424
Contracts.AssertValue(env);
25-
Environment = env;
25+
_env = env;
2626
}
2727

2828
/// <summary>
@@ -52,9 +52,9 @@ internal DataOperationsCatalog(IHostEnvironment env)
5252
public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefinition schemaDefinition = null)
5353
where TRow : class
5454
{
55-
Environment.CheckValue(data, nameof(data));
56-
Environment.CheckValueOrNull(schemaDefinition);
57-
return DataViewConstructionUtils.CreateFromEnumerable(Environment, data, schemaDefinition);
55+
_env.CheckValue(data, nameof(data));
56+
_env.CheckValueOrNull(schemaDefinition);
57+
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
5858
}
5959

6060
/// <summary>
@@ -77,10 +77,10 @@ public IEnumerable<TRow> CreateEnumerable<TRow>(IDataView data, bool reuseRowObj
7777
bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
7878
where TRow : class, new()
7979
{
80-
Environment.CheckValue(data, nameof(data));
81-
Environment.CheckValueOrNull(schemaDefinition);
80+
_env.CheckValue(data, nameof(data));
81+
_env.CheckValueOrNull(schemaDefinition);
8282

83-
var engine = new PipeEngine<TRow>(Environment, data, ignoreMissingColumns, schemaDefinition);
83+
var engine = new PipeEngine<TRow>(_env, data, ignoreMissingColumns, schemaDefinition);
8484
return engine.RunPipe(reuseRowObject);
8585
}
8686

@@ -109,9 +109,9 @@ public IDataView BootstrapSample(IDataView input,
109109
int? seed = null,
110110
bool complement = BootstrapSamplingTransformer.Defaults.Complement)
111111
{
112-
Environment.CheckValue(input, nameof(input));
112+
_env.CheckValue(input, nameof(input));
113113
return new BootstrapSamplingTransformer(
114-
Environment,
114+
_env,
115115
input,
116116
complement: complement,
117117
seed: (uint?) seed,
@@ -139,16 +139,16 @@ public IDataView BootstrapSample(IDataView input,
139139
/// </example>
140140
public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
141141
{
142-
Environment.CheckValue(input, nameof(input));
143-
Environment.CheckValueOrNull(columnsToPrefetch);
142+
_env.CheckValue(input, nameof(input));
143+
_env.CheckValueOrNull(columnsToPrefetch);
144144

145145
int[] prefetch = new int[Utils.Size(columnsToPrefetch)];
146146
for (int i = 0; i < prefetch.Length; i++)
147147
{
148148
if (!input.Schema.TryGetColumnIndex(columnsToPrefetch[i], out prefetch[i]))
149-
throw Environment.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", columnsToPrefetch[i]);
149+
throw _env.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", columnsToPrefetch[i]);
150150
}
151-
return new CacheDataView(Environment, input, prefetch);
151+
return new CacheDataView(_env, input, prefetch);
152152
}
153153

154154
/// <summary>
@@ -171,14 +171,14 @@ public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
171171
/// </example>
172172
public IDataView FilterRowsByColumn(IDataView input, string columnName, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
173173
{
174-
Environment.CheckValue(input, nameof(input));
175-
Environment.CheckNonEmpty(columnName, nameof(columnName));
176-
Environment.CheckParam(lowerBound < upperBound, nameof(upperBound), "Must be less than lowerBound");
174+
_env.CheckValue(input, nameof(input));
175+
_env.CheckNonEmpty(columnName, nameof(columnName));
176+
_env.CheckParam(lowerBound < upperBound, nameof(upperBound), "Must be less than lowerBound");
177177

178178
var type = input.Schema[columnName].Type;
179179
if (!(type is NumberDataViewType))
180-
throw Environment.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "number", type.ToString());
181-
return new RangeFilter(Environment, input, columnName, lowerBound, upperBound, false);
180+
throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "number", type.ToString());
181+
return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
182182
}
183183

184184
/// <summary>
@@ -203,16 +203,16 @@ public IDataView FilterRowsByColumn(IDataView input, string columnName, double l
203203
/// </example>
204204
public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnName, double lowerBound = 0, double upperBound = 1)
205205
{
206-
Environment.CheckValue(input, nameof(input));
207-
Environment.CheckNonEmpty(columnName, nameof(columnName));
208-
Environment.CheckParam(0 <= lowerBound && lowerBound <= 1, nameof(lowerBound), "Must be in [0, 1]");
209-
Environment.CheckParam(0 <= upperBound && upperBound <= 2, nameof(upperBound), "Must be in [0, 2]");
210-
Environment.CheckParam(lowerBound <= upperBound, nameof(upperBound), "Must be no less than lowerBound");
206+
_env.CheckValue(input, nameof(input));
207+
_env.CheckNonEmpty(columnName, nameof(columnName));
208+
_env.CheckParam(0 <= lowerBound && lowerBound <= 1, nameof(lowerBound), "Must be in [0, 1]");
209+
_env.CheckParam(0 <= upperBound && upperBound <= 2, nameof(upperBound), "Must be in [0, 2]");
210+
_env.CheckParam(lowerBound <= upperBound, nameof(upperBound), "Must be no less than lowerBound");
211211

212212
var type = input.Schema[columnName].Type;
213213
if (type.GetKeyCount() == 0)
214-
throw Environment.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "KeyType", type.ToString());
215-
return new RangeFilter(Environment, input, columnName, lowerBound, upperBound, false);
214+
throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "KeyType", type.ToString());
215+
return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
216216
}
217217

218218
/// <summary>
@@ -230,10 +230,10 @@ public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnNam
230230
/// </example>
231231
public IDataView FilterRowsByMissingValues(IDataView input, params string[] columns)
232232
{
233-
Environment.CheckValue(input, nameof(input));
234-
Environment.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
233+
_env.CheckValue(input, nameof(input));
234+
_env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
235235

236-
return new NAFilter(Environment, input, complement: false, columns);
236+
return new NAFilter(_env, input, complement: false, columns);
237237
}
238238

239239
/// <summary>
@@ -268,8 +268,8 @@ public IDataView ShuffleRows(IDataView input,
268268
int shufflePoolSize = RowShufflingTransformer.Defaults.PoolRows,
269269
bool shuffleSource = !RowShufflingTransformer.Defaults.PoolOnly)
270270
{
271-
Environment.CheckValue(input, nameof(input));
272-
Environment.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");
271+
_env.CheckValue(input, nameof(input));
272+
_env.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");
273273

274274
var options = new RowShufflingTransformer.Options
275275
{
@@ -279,7 +279,7 @@ public IDataView ShuffleRows(IDataView input,
279279
ForceShuffleSeed = seed
280280
};
281281

282-
return new RowShufflingTransformer(Environment, options, input);
282+
return new RowShufflingTransformer(_env, options, input);
283283
}
284284

285285
/// <summary>
@@ -299,15 +299,15 @@ public IDataView ShuffleRows(IDataView input,
299299
/// </example>
300300
public IDataView SkipRows(IDataView input, long count)
301301
{
302-
Environment.CheckValue(input, nameof(input));
303-
Environment.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
302+
_env.CheckValue(input, nameof(input));
303+
_env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
304304

305305
var options = new SkipTakeFilter.SkipOptions()
306306
{
307307
Count = count
308308
};
309309

310-
return new SkipTakeFilter(Environment, options, input);
310+
return new SkipTakeFilter(_env, options, input);
311311
}
312312

313313
/// <summary>
@@ -327,15 +327,15 @@ public IDataView SkipRows(IDataView input, long count)
327327
/// </example>
328328
public IDataView TakeRows(IDataView input, long count)
329329
{
330-
Environment.CheckValue(input, nameof(input));
331-
Environment.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
330+
_env.CheckValue(input, nameof(input));
331+
_env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
332332

333333
var options = new SkipTakeFilter.TakeOptions()
334334
{
335335
Count = count
336336
};
337337

338-
return new SkipTakeFilter(Environment, options, input);
338+
return new SkipTakeFilter(_env, options, input);
339339
}
340340
}
341341
}

src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ namespace Microsoft.ML.Data.DataLoadSave
1313
/// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10,
1414
/// and all values are defaults (for annotations).
1515
/// </summary>
16+
[BestFriend]
1617
internal static class FakeSchemaFactory
1718
{
1819
private const int AllVectorSizes = 10;

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

+12-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515

1616
namespace Microsoft.ML.Data
1717
{
18-
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it.
19-
// It needs to become internal.
20-
public sealed class TransformWrapper : ITransformer
18+
/// <summary>
19+
/// This is a shim class to present the legacy <see cref="IDataTransform"/> interface as an <see cref="ITransformer"/>.
20+
/// Note that there are some important differences in usages that make this shimming somewhat non-seemless, so the goal
21+
/// would be gradual removal of this as we do away with <see cref="IDataTransform"/> based code.
22+
/// </summary>
23+
[BestFriend]
24+
internal sealed class TransformWrapper : ITransformer
2125
{
2226
internal const string LoaderSignature = "TransformWrapper";
2327
private const string TransformDirTemplate = "Step_{0:000}";
@@ -148,11 +152,13 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
148152
/// <summary>
149153
/// Estimator for trained wrapped transformers.
150154
/// </summary>
151-
public abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
155+
internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
152156
{
153-
protected readonly IHost Host;
157+
[BestFriend]
158+
private protected readonly IHost Host;
154159

155-
protected TrainedWrapperEstimatorBase(IHost host)
160+
[BestFriend]
161+
private protected TrainedWrapperEstimatorBase(IHost host)
156162
{
157163
Contracts.CheckValue(host, nameof(host));
158164
Host = host;

src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@ namespace Microsoft.ML.Data
1616
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
1717
where TTransformer : class, ITransformer
1818
{
19-
protected readonly IHost Host;
20-
protected readonly TTransformer Transformer;
19+
[BestFriend]
20+
private protected readonly IHost Host;
21+
[BestFriend]
22+
private protected readonly TTransformer Transformer;
2123

22-
protected TrivialEstimator(IHost host, TTransformer transformer)
24+
[BestFriend]
25+
private protected TrivialEstimator(IHost host, TTransformer transformer)
2326
{
2427
Contracts.AssertValue(host);
2528

src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ public abstract partial class AnnotationInfo
740740

741741
internal abstract Delegate GetGetterDelegate();
742742

743-
protected AnnotationInfo(string kind, DataViewType annotationType)
743+
private protected AnnotationInfo(string kind, DataViewType annotationType)
744744
{
745745
Contracts.AssertValueOrNull(annotationType);
746746
Contracts.AssertNonEmpty(kind);

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

+13-22
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,46 @@ namespace Microsoft.ML
1010
/// <summary>
1111
/// An object serving as a 'catalog' of available model operations.
1212
/// </summary>
13-
public sealed class ModelOperationsCatalog
13+
public sealed class ModelOperationsCatalog : IInternalCatalog
1414
{
15-
/// <summary>
16-
/// This is a best friend because an extension method defined in another assembly needs this field.
17-
/// </summary>
18-
[BestFriend]
19-
internal IHostEnvironment Environment { get; }
15+
IHostEnvironment IInternalCatalog.Environment => _env;
16+
private readonly IHostEnvironment _env;
2017

2118
public ExplainabilityTransforms Explainability { get; }
2219

2320
internal ModelOperationsCatalog(IHostEnvironment env)
2421
{
2522
Contracts.AssertValue(env);
26-
Environment = env;
23+
_env = env;
2724

2825
Explainability = new ExplainabilityTransforms(this);
2926
}
3027

31-
public abstract class SubCatalogBase
32-
{
33-
internal IHostEnvironment Environment { get; }
34-
35-
protected SubCatalogBase(ModelOperationsCatalog owner)
36-
{
37-
Environment = owner.Environment;
38-
}
39-
}
40-
4128
/// <summary>
4229
/// Save the model to the stream.
4330
/// </summary>
4431
/// <param name="model">The trained model to be saved.</param>
4532
/// <param name="stream">A writeable, seekable stream to save to.</param>
46-
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream);
33+
public void Save(ITransformer model, Stream stream) => model.SaveTo(_env, stream);
4734

4835
/// <summary>
4936
/// Load the model from the stream.
5037
/// </summary>
5138
/// <param name="stream">A readable, seekable stream to load from.</param>
5239
/// <returns>The loaded model.</returns>
53-
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);
40+
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream);
5441

5542
/// <summary>
5643
/// The catalog of model explainability operations.
5744
/// </summary>
58-
public sealed class ExplainabilityTransforms : SubCatalogBase
45+
public sealed class ExplainabilityTransforms : IInternalCatalog
5946
{
60-
internal ExplainabilityTransforms(ModelOperationsCatalog owner) : base(owner)
47+
IHostEnvironment IInternalCatalog.Environment => _env;
48+
private readonly IHostEnvironment _env;
49+
50+
internal ExplainabilityTransforms(ModelOperationsCatalog owner)
6151
{
52+
_env = owner._env;
6253
}
6354
}
6455

@@ -75,7 +66,7 @@ public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransfor
7566
where TSrc : class
7667
where TDst : class, new()
7768
{
78-
return new PredictionEngine<TSrc, TDst>(Environment, transformer, false, inputSchemaDefinition, outputSchemaDefinition);
69+
return new PredictionEngine<TSrc, TDst>(_env, transformer, false, inputSchemaDefinition, outputSchemaDefinition);
7970
}
8071
}
8172
}

src/Microsoft.ML.Data/Prediction/PredictionEngine.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable
103103
/// <summary>
104104
/// Provides output schema.
105105
/// </summary>
106-
public DataViewSchema OutputSchema;
106+
public DataViewSchema OutputSchema { get; }
107107

108108
[BestFriend]
109109
private protected ITransformer Transformer { get; }
@@ -180,9 +180,11 @@ public TDst Predict(TSrc example)
180180
return result;
181181
}
182182

183-
protected void ExtractValues(TSrc example) => _inputRow.ExtractValues(example);
183+
[BestFriend]
184+
private protected void ExtractValues(TSrc example) => _inputRow.ExtractValues(example);
184185

185-
protected void FillValues(TDst prediction) => _outputRow.FillValues(prediction);
186+
[BestFriend]
187+
private protected void FillValues(TDst prediction) => _outputRow.FillValues(prediction);
186188

187189
/// <summary>
188190
/// Run prediction pipeline on one example.

0 commit comments

Comments
 (0)