@@ -14,15 +14,15 @@ namespace Microsoft.ML
14
14
/// A catalog of operations over data that are not transformers or estimators.
15
15
/// This includes data loaders, saving, caching, filtering etc.
16
16
/// </summary>
17
- public sealed class DataOperationsCatalog
17
+ public sealed class DataOperationsCatalog : IInternalCatalog
18
18
{
19
- [ BestFriend ]
20
- internal IHostEnvironment Environment { get ; }
19
+ IHostEnvironment IInternalCatalog . Environment => _env ;
20
+ private readonly IHostEnvironment _env ;
21
21
22
22
internal DataOperationsCatalog ( IHostEnvironment env )
23
23
{
24
24
Contracts . AssertValue ( env ) ;
25
- Environment = env ;
25
+ _env = env ;
26
26
}
27
27
28
28
/// <summary>
@@ -52,9 +52,9 @@ internal DataOperationsCatalog(IHostEnvironment env)
52
52
public IDataView LoadFromEnumerable < TRow > ( IEnumerable < TRow > data , SchemaDefinition schemaDefinition = null )
53
53
where TRow : class
54
54
{
55
- Environment . CheckValue ( data , nameof ( data ) ) ;
56
- Environment . CheckValueOrNull ( schemaDefinition ) ;
57
- return DataViewConstructionUtils . CreateFromEnumerable ( Environment , data , schemaDefinition ) ;
55
+ _env . CheckValue ( data , nameof ( data ) ) ;
56
+ _env . CheckValueOrNull ( schemaDefinition ) ;
57
+ return DataViewConstructionUtils . CreateFromEnumerable ( _env , data , schemaDefinition ) ;
58
58
}
59
59
60
60
/// <summary>
@@ -77,10 +77,10 @@ public IEnumerable<TRow> CreateEnumerable<TRow>(IDataView data, bool reuseRowObj
77
77
bool ignoreMissingColumns = false , SchemaDefinition schemaDefinition = null )
78
78
where TRow : class , new ( )
79
79
{
80
- Environment . CheckValue ( data , nameof ( data ) ) ;
81
- Environment . CheckValueOrNull ( schemaDefinition ) ;
80
+ _env . CheckValue ( data , nameof ( data ) ) ;
81
+ _env . CheckValueOrNull ( schemaDefinition ) ;
82
82
83
- var engine = new PipeEngine < TRow > ( Environment , data , ignoreMissingColumns , schemaDefinition ) ;
83
+ var engine = new PipeEngine < TRow > ( _env , data , ignoreMissingColumns , schemaDefinition ) ;
84
84
return engine . RunPipe ( reuseRowObject ) ;
85
85
}
86
86
@@ -109,9 +109,9 @@ public IDataView BootstrapSample(IDataView input,
109
109
int ? seed = null ,
110
110
bool complement = BootstrapSamplingTransformer . Defaults . Complement )
111
111
{
112
- Environment . CheckValue ( input , nameof ( input ) ) ;
112
+ _env . CheckValue ( input , nameof ( input ) ) ;
113
113
return new BootstrapSamplingTransformer (
114
- Environment ,
114
+ _env ,
115
115
input ,
116
116
complement : complement ,
117
117
seed : ( uint ? ) seed ,
@@ -139,16 +139,16 @@ public IDataView BootstrapSample(IDataView input,
139
139
/// </example>
140
140
public IDataView Cache ( IDataView input , params string [ ] columnsToPrefetch )
141
141
{
142
- Environment . CheckValue ( input , nameof ( input ) ) ;
143
- Environment . CheckValueOrNull ( columnsToPrefetch ) ;
142
+ _env . CheckValue ( input , nameof ( input ) ) ;
143
+ _env . CheckValueOrNull ( columnsToPrefetch ) ;
144
144
145
145
int [ ] prefetch = new int [ Utils . Size ( columnsToPrefetch ) ] ;
146
146
for ( int i = 0 ; i < prefetch . Length ; i ++ )
147
147
{
148
148
if ( ! input . Schema . TryGetColumnIndex ( columnsToPrefetch [ i ] , out prefetch [ i ] ) )
149
- throw Environment . ExceptSchemaMismatch ( nameof ( columnsToPrefetch ) , "prefetch" , columnsToPrefetch [ i ] ) ;
149
+ throw _env . ExceptSchemaMismatch ( nameof ( columnsToPrefetch ) , "prefetch" , columnsToPrefetch [ i ] ) ;
150
150
}
151
- return new CacheDataView ( Environment , input , prefetch ) ;
151
+ return new CacheDataView ( _env , input , prefetch ) ;
152
152
}
153
153
154
154
/// <summary>
@@ -171,14 +171,14 @@ public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
171
171
/// </example>
172
172
public IDataView FilterRowsByColumn ( IDataView input , string columnName , double lowerBound = double . NegativeInfinity , double upperBound = double . PositiveInfinity )
173
173
{
174
- Environment . CheckValue ( input , nameof ( input ) ) ;
175
- Environment . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
176
- Environment . CheckParam ( lowerBound < upperBound , nameof ( upperBound ) , "Must be less than lowerBound" ) ;
174
+ _env . CheckValue ( input , nameof ( input ) ) ;
175
+ _env . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
176
+ _env . CheckParam ( lowerBound < upperBound , nameof ( upperBound ) , "Must be less than lowerBound" ) ;
177
177
178
178
var type = input . Schema [ columnName ] . Type ;
179
179
if ( ! ( type is NumberDataViewType ) )
180
- throw Environment . ExceptSchemaMismatch ( nameof ( columnName ) , "filter" , columnName , "number" , type . ToString ( ) ) ;
181
- return new RangeFilter ( Environment , input , columnName , lowerBound , upperBound , false ) ;
180
+ throw _env . ExceptSchemaMismatch ( nameof ( columnName ) , "filter" , columnName , "number" , type . ToString ( ) ) ;
181
+ return new RangeFilter ( _env , input , columnName , lowerBound , upperBound , false ) ;
182
182
}
183
183
184
184
/// <summary>
@@ -203,16 +203,16 @@ public IDataView FilterRowsByColumn(IDataView input, string columnName, double l
203
203
/// </example>
204
204
public IDataView FilterRowsByKeyColumnFraction ( IDataView input , string columnName , double lowerBound = 0 , double upperBound = 1 )
205
205
{
206
- Environment . CheckValue ( input , nameof ( input ) ) ;
207
- Environment . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
208
- Environment . CheckParam ( 0 <= lowerBound && lowerBound <= 1 , nameof ( lowerBound ) , "Must be in [0, 1]" ) ;
209
- Environment . CheckParam ( 0 <= upperBound && upperBound <= 2 , nameof ( upperBound ) , "Must be in [0, 2]" ) ;
210
- Environment . CheckParam ( lowerBound <= upperBound , nameof ( upperBound ) , "Must be no less than lowerBound" ) ;
206
+ _env . CheckValue ( input , nameof ( input ) ) ;
207
+ _env . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
208
+ _env . CheckParam ( 0 <= lowerBound && lowerBound <= 1 , nameof ( lowerBound ) , "Must be in [0, 1]" ) ;
209
+ _env . CheckParam ( 0 <= upperBound && upperBound <= 2 , nameof ( upperBound ) , "Must be in [0, 2]" ) ;
210
+ _env . CheckParam ( lowerBound <= upperBound , nameof ( upperBound ) , "Must be no less than lowerBound" ) ;
211
211
212
212
var type = input . Schema [ columnName ] . Type ;
213
213
if ( type . GetKeyCount ( ) == 0 )
214
- throw Environment . ExceptSchemaMismatch ( nameof ( columnName ) , "filter" , columnName , "KeyType" , type . ToString ( ) ) ;
215
- return new RangeFilter ( Environment , input , columnName , lowerBound , upperBound , false ) ;
214
+ throw _env . ExceptSchemaMismatch ( nameof ( columnName ) , "filter" , columnName , "KeyType" , type . ToString ( ) ) ;
215
+ return new RangeFilter ( _env , input , columnName , lowerBound , upperBound , false ) ;
216
216
}
217
217
218
218
/// <summary>
@@ -230,10 +230,10 @@ public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnNam
230
230
/// </example>
231
231
public IDataView FilterRowsByMissingValues ( IDataView input , params string [ ] columns )
232
232
{
233
- Environment . CheckValue ( input , nameof ( input ) ) ;
234
- Environment . CheckUserArg ( Utils . Size ( columns ) > 0 , nameof ( columns ) ) ;
233
+ _env . CheckValue ( input , nameof ( input ) ) ;
234
+ _env . CheckUserArg ( Utils . Size ( columns ) > 0 , nameof ( columns ) ) ;
235
235
236
- return new NAFilter ( Environment , input , complement : false , columns ) ;
236
+ return new NAFilter ( _env , input , complement : false , columns ) ;
237
237
}
238
238
239
239
/// <summary>
@@ -268,8 +268,8 @@ public IDataView ShuffleRows(IDataView input,
268
268
int shufflePoolSize = RowShufflingTransformer . Defaults . PoolRows ,
269
269
bool shuffleSource = ! RowShufflingTransformer . Defaults . PoolOnly )
270
270
{
271
- Environment . CheckValue ( input , nameof ( input ) ) ;
272
- Environment . CheckUserArg ( shufflePoolSize > 0 , nameof ( shufflePoolSize ) , "Must be positive" ) ;
271
+ _env . CheckValue ( input , nameof ( input ) ) ;
272
+ _env . CheckUserArg ( shufflePoolSize > 0 , nameof ( shufflePoolSize ) , "Must be positive" ) ;
273
273
274
274
var options = new RowShufflingTransformer . Options
275
275
{
@@ -279,7 +279,7 @@ public IDataView ShuffleRows(IDataView input,
279
279
ForceShuffleSeed = seed
280
280
} ;
281
281
282
- return new RowShufflingTransformer ( Environment , options , input ) ;
282
+ return new RowShufflingTransformer ( _env , options , input ) ;
283
283
}
284
284
285
285
/// <summary>
@@ -299,15 +299,15 @@ public IDataView ShuffleRows(IDataView input,
299
299
/// </example>
300
300
public IDataView SkipRows ( IDataView input , long count )
301
301
{
302
- Environment . CheckValue ( input , nameof ( input ) ) ;
303
- Environment . CheckUserArg ( count > 0 , nameof ( count ) , "Must be greater than zero." ) ;
302
+ _env . CheckValue ( input , nameof ( input ) ) ;
303
+ _env . CheckUserArg ( count > 0 , nameof ( count ) , "Must be greater than zero." ) ;
304
304
305
305
var options = new SkipTakeFilter . SkipOptions ( )
306
306
{
307
307
Count = count
308
308
} ;
309
309
310
- return new SkipTakeFilter ( Environment , options , input ) ;
310
+ return new SkipTakeFilter ( _env , options , input ) ;
311
311
}
312
312
313
313
/// <summary>
@@ -327,15 +327,15 @@ public IDataView SkipRows(IDataView input, long count)
327
327
/// </example>
328
328
public IDataView TakeRows ( IDataView input , long count )
329
329
{
330
- Environment . CheckValue ( input , nameof ( input ) ) ;
331
- Environment . CheckUserArg ( count > 0 , nameof ( count ) , "Must be greater than zero." ) ;
330
+ _env . CheckValue ( input , nameof ( input ) ) ;
331
+ _env . CheckUserArg ( count > 0 , nameof ( count ) , "Must be greater than zero." ) ;
332
332
333
333
var options = new SkipTakeFilter . TakeOptions ( )
334
334
{
335
335
Count = count
336
336
} ;
337
337
338
- return new SkipTakeFilter ( Environment , options , input ) ;
338
+ return new SkipTakeFilter ( _env , options , input ) ;
339
339
}
340
340
}
341
341
}
0 commit comments