Skip to content

Commit 7b714fd

Browse files
committed
Use explicit IInternalCatalog, thereby avoiding SubCatalogBase in public surface. Simplify CatalogUtils.
1 parent b733c99 commit 7b714fd

File tree

9 files changed

+121
-101
lines changed

9 files changed

+121
-101
lines changed

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

+40-40
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,15 @@ namespace Microsoft.ML
1414
/// A catalog of operations over data that are not transformers or estimators.
1515
/// This includes data loaders, saving, caching, filtering etc.
1616
/// </summary>
17-
public sealed class DataOperationsCatalog
17+
public sealed class DataOperationsCatalog : IInternalCatalog
1818
{
19-
[BestFriend]
20-
internal IHostEnvironment Environment { get; }
19+
IHostEnvironment IInternalCatalog.Environment => _env;
20+
private readonly IHostEnvironment _env;
2121

2222
internal DataOperationsCatalog(IHostEnvironment env)
2323
{
2424
Contracts.AssertValue(env);
25-
Environment = env;
25+
_env = env;
2626
}
2727

2828
/// <summary>
@@ -52,9 +52,9 @@ internal DataOperationsCatalog(IHostEnvironment env)
5252
public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefinition schemaDefinition = null)
5353
where TRow : class
5454
{
55-
Environment.CheckValue(data, nameof(data));
56-
Environment.CheckValueOrNull(schemaDefinition);
57-
return DataViewConstructionUtils.CreateFromEnumerable(Environment, data, schemaDefinition);
55+
_env.CheckValue(data, nameof(data));
56+
_env.CheckValueOrNull(schemaDefinition);
57+
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
5858
}
5959

6060
/// <summary>
@@ -77,10 +77,10 @@ public IEnumerable<TRow> CreateEnumerable<TRow>(IDataView data, bool reuseRowObj
7777
bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
7878
where TRow : class, new()
7979
{
80-
Environment.CheckValue(data, nameof(data));
81-
Environment.CheckValueOrNull(schemaDefinition);
80+
_env.CheckValue(data, nameof(data));
81+
_env.CheckValueOrNull(schemaDefinition);
8282

83-
var engine = new PipeEngine<TRow>(Environment, data, ignoreMissingColumns, schemaDefinition);
83+
var engine = new PipeEngine<TRow>(_env, data, ignoreMissingColumns, schemaDefinition);
8484
return engine.RunPipe(reuseRowObject);
8585
}
8686

@@ -109,9 +109,9 @@ public IDataView BootstrapSample(IDataView input,
109109
int? seed = null,
110110
bool complement = BootstrapSamplingTransformer.Defaults.Complement)
111111
{
112-
Environment.CheckValue(input, nameof(input));
112+
_env.CheckValue(input, nameof(input));
113113
return new BootstrapSamplingTransformer(
114-
Environment,
114+
_env,
115115
input,
116116
complement: complement,
117117
seed: (uint?) seed,
@@ -139,16 +139,16 @@ public IDataView BootstrapSample(IDataView input,
139139
/// </example>
140140
public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
141141
{
142-
Environment.CheckValue(input, nameof(input));
143-
Environment.CheckValueOrNull(columnsToPrefetch);
142+
_env.CheckValue(input, nameof(input));
143+
_env.CheckValueOrNull(columnsToPrefetch);
144144

145145
int[] prefetch = new int[Utils.Size(columnsToPrefetch)];
146146
for (int i = 0; i < prefetch.Length; i++)
147147
{
148148
if (!input.Schema.TryGetColumnIndex(columnsToPrefetch[i], out prefetch[i]))
149-
throw Environment.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", columnsToPrefetch[i]);
149+
throw _env.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", columnsToPrefetch[i]);
150150
}
151-
return new CacheDataView(Environment, input, prefetch);
151+
return new CacheDataView(_env, input, prefetch);
152152
}
153153

154154
/// <summary>
@@ -171,14 +171,14 @@ public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
171171
/// </example>
172172
public IDataView FilterRowsByColumn(IDataView input, string columnName, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
173173
{
174-
Environment.CheckValue(input, nameof(input));
175-
Environment.CheckNonEmpty(columnName, nameof(columnName));
176-
Environment.CheckParam(lowerBound < upperBound, nameof(upperBound), "Must be less than lowerBound");
174+
_env.CheckValue(input, nameof(input));
175+
_env.CheckNonEmpty(columnName, nameof(columnName));
176+
_env.CheckParam(lowerBound < upperBound, nameof(upperBound), "Must be less than lowerBound");
177177

178178
var type = input.Schema[columnName].Type;
179179
if (!(type is NumberDataViewType))
180-
throw Environment.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "number", type.ToString());
181-
return new RangeFilter(Environment, input, columnName, lowerBound, upperBound, false);
180+
throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "number", type.ToString());
181+
return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
182182
}
183183

184184
/// <summary>
@@ -203,16 +203,16 @@ public IDataView FilterRowsByColumn(IDataView input, string columnName, double l
203203
/// </example>
204204
public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnName, double lowerBound = 0, double upperBound = 1)
205205
{
206-
Environment.CheckValue(input, nameof(input));
207-
Environment.CheckNonEmpty(columnName, nameof(columnName));
208-
Environment.CheckParam(0 <= lowerBound && lowerBound <= 1, nameof(lowerBound), "Must be in [0, 1]");
209-
Environment.CheckParam(0 <= upperBound && upperBound <= 2, nameof(upperBound), "Must be in [0, 2]");
210-
Environment.CheckParam(lowerBound <= upperBound, nameof(upperBound), "Must be no less than lowerBound");
206+
_env.CheckValue(input, nameof(input));
207+
_env.CheckNonEmpty(columnName, nameof(columnName));
208+
_env.CheckParam(0 <= lowerBound && lowerBound <= 1, nameof(lowerBound), "Must be in [0, 1]");
209+
_env.CheckParam(0 <= upperBound && upperBound <= 2, nameof(upperBound), "Must be in [0, 2]");
210+
_env.CheckParam(lowerBound <= upperBound, nameof(upperBound), "Must be no less than lowerBound");
211211

212212
var type = input.Schema[columnName].Type;
213213
if (type.GetKeyCount() == 0)
214-
throw Environment.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "KeyType", type.ToString());
215-
return new RangeFilter(Environment, input, columnName, lowerBound, upperBound, false);
214+
throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "KeyType", type.ToString());
215+
return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
216216
}
217217

218218
/// <summary>
@@ -230,10 +230,10 @@ public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnNam
230230
/// </example>
231231
public IDataView FilterRowsByMissingValues(IDataView input, params string[] columns)
232232
{
233-
Environment.CheckValue(input, nameof(input));
234-
Environment.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
233+
_env.CheckValue(input, nameof(input));
234+
_env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
235235

236-
return new NAFilter(Environment, input, complement: false, columns);
236+
return new NAFilter(_env, input, complement: false, columns);
237237
}
238238

239239
/// <summary>
@@ -268,8 +268,8 @@ public IDataView ShuffleRows(IDataView input,
268268
int shufflePoolSize = RowShufflingTransformer.Defaults.PoolRows,
269269
bool shuffleSource = !RowShufflingTransformer.Defaults.PoolOnly)
270270
{
271-
Environment.CheckValue(input, nameof(input));
272-
Environment.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");
271+
_env.CheckValue(input, nameof(input));
272+
_env.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");
273273

274274
var options = new RowShufflingTransformer.Options
275275
{
@@ -279,7 +279,7 @@ public IDataView ShuffleRows(IDataView input,
279279
ForceShuffleSeed = seed
280280
};
281281

282-
return new RowShufflingTransformer(Environment, options, input);
282+
return new RowShufflingTransformer(_env, options, input);
283283
}
284284

285285
/// <summary>
@@ -299,15 +299,15 @@ public IDataView ShuffleRows(IDataView input,
299299
/// </example>
300300
public IDataView SkipRows(IDataView input, long count)
301301
{
302-
Environment.CheckValue(input, nameof(input));
303-
Environment.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
302+
_env.CheckValue(input, nameof(input));
303+
_env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
304304

305305
var options = new SkipTakeFilter.SkipOptions()
306306
{
307307
Count = count
308308
};
309309

310-
return new SkipTakeFilter(Environment, options, input);
310+
return new SkipTakeFilter(_env, options, input);
311311
}
312312

313313
/// <summary>
@@ -327,15 +327,15 @@ public IDataView SkipRows(IDataView input, long count)
327327
/// </example>
328328
public IDataView TakeRows(IDataView input, long count)
329329
{
330-
Environment.CheckValue(input, nameof(input));
331-
Environment.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
330+
_env.CheckValue(input, nameof(input));
331+
_env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");
332332

333333
var options = new SkipTakeFilter.TakeOptions()
334334
{
335335
Count = count
336336
};
337337

338-
return new SkipTakeFilter(Environment, options, input);
338+
return new SkipTakeFilter(_env, options, input);
339339
}
340340
}
341341
}

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

+11-13
Original file line number | Diff line number | Diff line change
@@ -10,20 +10,17 @@ namespace Microsoft.ML
1010
/// <summary>
1111
/// An object serving as a 'catalog' of available model operations.
1212
/// </summary>
13-
public sealed class ModelOperationsCatalog
13+
public sealed class ModelOperationsCatalog : IInternalCatalog
1414
{
15-
/// <summary>
16-
/// This is a best friend because an extension method defined in another assembly needs this field.
17-
/// </summary>
18-
[BestFriend]
19-
internal IHostEnvironment Environment { get; }
15+
IHostEnvironment IInternalCatalog.Environment => _env;
16+
private readonly IHostEnvironment _env;
2017

2118
public ExplainabilityTransforms Explainability { get; }
2219

2320
internal ModelOperationsCatalog(IHostEnvironment env)
2421
{
2522
Contracts.AssertValue(env);
26-
Environment = env;
23+
_env = env;
2724

2825
Explainability = new ExplainabilityTransforms(this);
2926
}
@@ -33,25 +30,26 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3330
/// </summary>
3431
/// <param name="model">The trained model to be saved.</param>
3532
/// <param name="stream">A writeable, seekable stream to save to.</param>
36-
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream);
33+
public void Save(ITransformer model, Stream stream) => model.SaveTo(_env, stream);
3734

3835
/// <summary>
3936
/// Load the model from the stream.
4037
/// </summary>
4138
/// <param name="stream">A readable, seekable stream to load from.</param>
4239
/// <returns>The loaded model.</returns>
43-
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);
40+
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream);
4441

4542
/// <summary>
4643
/// The catalog of model explainability operations.
4744
/// </summary>
48-
public sealed class ExplainabilityTransforms
45+
public sealed class ExplainabilityTransforms : IInternalCatalog
4946
{
50-
internal IHostEnvironment Environment { get; }
47+
IHostEnvironment IInternalCatalog.Environment => _env;
48+
private readonly IHostEnvironment _env;
5149

5250
internal ExplainabilityTransforms(ModelOperationsCatalog owner)
5351
{
54-
Environment = owner.Environment;
52+
_env = owner._env;
5553
}
5654
}
5755

@@ -68,7 +66,7 @@ public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransfor
6866
where TSrc : class
6967
where TDst : class, new()
7068
{
71-
return new PredictionEngine<TSrc, TDst>(Environment, transformer, false, inputSchemaDefinition, outputSchemaDefinition);
69+
return new PredictionEngine<TSrc, TDst>(_env, transformer, false, inputSchemaDefinition, outputSchemaDefinition);
7270
}
7371
}
7472
}

src/Microsoft.ML.Data/TrainCatalog.cs

+11-7
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,12 @@ namespace Microsoft.ML
1616
/// "area" of machine learning. A subclass would represent a particular task in machine learning. The idea
1717
/// is that a user can instantiate that particular area, and get trainers and evaluators.
1818
/// </summary>
19-
public abstract class TrainCatalogBase
19+
public abstract class TrainCatalogBase : IInternalCatalog
2020
{
21+
IHostEnvironment IInternalCatalog.Environment => Environment;
22+
2123
[BestFriend]
22-
internal IHostEnvironment Environment { get; }
24+
private protected IHostEnvironment Environment { get; }
2325

2426
/// <summary>
2527
/// A pair of datasets, for the train and test set.
@@ -245,8 +247,10 @@ private void EnsureGroupPreservationColumn(ref IDataView data, ref string sampli
245247
/// through <see cref="CatalogUtils"/> to get more "hidden" information from this object,
246248
/// for example, the environment.
247249
/// </summary>
248-
public abstract class CatalogInstantiatorBase
250+
public abstract class CatalogInstantiatorBase : IInternalCatalog
249251
{
252+
IHostEnvironment IInternalCatalog.Environment => Owner.GetEnvironment();
253+
250254
[BestFriend]
251255
internal TrainCatalogBase Owner { get; }
252256

@@ -402,7 +406,7 @@ public NaiveCalibratorEstimator Naive(
402406
string labelColumnName = DefaultColumnNames.Label,
403407
string scoreColumnName = DefaultColumnNames.Score)
404408
{
405-
return new NaiveCalibratorEstimator(Owner.Environment, labelColumnName, scoreColumnName);
409+
return new NaiveCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName);
406410
}
407411
/// <summary>
408412
/// Adds a probability column by training a <a href="https://en.wikipedia.org/wiki/Platt_scaling">Platt calibrator</a>.
@@ -422,7 +426,7 @@ public PlattCalibratorEstimator Platt(
422426
string scoreColumnName = DefaultColumnNames.Score,
423427
string exampleWeightColumnName = null)
424428
{
425-
return new PlattCalibratorEstimator(Owner.Environment, labelColumnName, scoreColumnName, exampleWeightColumnName);
429+
return new PlattCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName, exampleWeightColumnName);
426430
}
427431

428432
/// <summary>
@@ -443,7 +447,7 @@ public FixedPlattCalibratorEstimator Platt(
443447
double offset,
444448
string scoreColumnName = DefaultColumnNames.Score)
445449
{
446-
return new FixedPlattCalibratorEstimator(Owner.Environment, slope, offset, scoreColumnName);
450+
return new FixedPlattCalibratorEstimator(Owner.GetEnvironment(), slope, offset, scoreColumnName);
447451
}
448452

449453
/// <summary>
@@ -468,7 +472,7 @@ public IsotonicCalibratorEstimator Isotonic(
468472
string scoreColumnName = DefaultColumnNames.Score,
469473
string exampleWeightColumnName = null)
470474
{
471-
return new IsotonicCalibratorEstimator(Owner.Environment, labelColumnName, scoreColumnName, exampleWeightColumnName);
475+
return new IsotonicCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName, exampleWeightColumnName);
472476
}
473477
}
474478
}

src/Microsoft.ML.Data/Transforms/CatalogUtils.cs

+14-8
Original file line number | Diff line number | Diff line change
@@ -5,17 +5,23 @@
55
namespace Microsoft.ML.Data
66
{
77
/// <summary>
8-
/// Set of extension methods to extract <see cref="IHostEnvironment"/> from various catalog classes.
8+
/// Convenience method to more easily extract an <see cref="IHostEnvironment"/> from an <see cref="IInternalCatalog"/>
9+
/// implementor without requiring an explicit cast.
910
/// </summary>
1011
[BestFriend]
1112
internal static class CatalogUtils
1213
{
13-
public static IHostEnvironment GetEnvironment(this TransformsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
14-
public static IHostEnvironment GetEnvironment(this TransformsCatalog.SubCatalogBase subCatalog) => Contracts.CheckRef(subCatalog, nameof(subCatalog)).Environment;
15-
public static IHostEnvironment GetEnvironment(this ModelOperationsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
16-
public static IHostEnvironment GetEnvironment(this ModelOperationsCatalog.ExplainabilityTransforms catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
17-
public static IHostEnvironment GetEnvironment(this DataOperationsCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
18-
public static IHostEnvironment GetEnvironment(TrainCatalogBase.CatalogInstantiatorBase obj) => Contracts.CheckRef(obj, nameof(obj)).Owner.Environment;
19-
public static IHostEnvironment GetEnvironment(TrainCatalogBase catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
14+
public static IHostEnvironment GetEnvironment(this IInternalCatalog catalog) => Contracts.CheckRef(catalog, nameof(catalog)).Environment;
15+
}
16+
17+
/// <summary>
18+
/// An internal interface for the benefit of those <see cref="IHostEnvironment"/>-bearing objects accessible through
19+
/// <see cref="MLContext"/>. Because this is meant to be consumed by component authors, implementations of this interface
20+
/// should be explicit.
21+
/// </summary>
22+
[BestFriend]
23+
internal interface IInternalCatalog
24+
{
25+
IHostEnvironment Environment { get; }
2026
}
2127
}

0 commit comments

Comments
 (0)