15
15
using Microsoft . ML . Runtime . Internal . Utilities ;
16
16
using Microsoft . ML . Runtime . Internal . Internallearn ;
17
17
using Microsoft . ML . Core . Data ;
18
+ using Microsoft . ML . Data . StaticPipe . Runtime ;
18
19
19
20
[ assembly: LoadableClass ( CategoricalTransform . Summary , typeof ( IDataTransform ) , typeof ( CategoricalTransform ) , typeof ( CategoricalTransform . Arguments ) , typeof ( SignatureDataTransform ) ,
20
21
CategoricalTransform . UserName , "CategoricalTransform" , "CatTransform" , "Categorical" , "Cat" ) ]
@@ -119,7 +120,7 @@ public Arguments()
119
120
120
121
public const string UserName = "Categorical Transform" ;
121
122
122
- public static IDataView Create ( IHostEnvironment env , Arguments args , IDataView input )
123
+ public static IDataTransform Create ( IHostEnvironment env , Arguments args , IDataView input )
123
124
{
124
125
Contracts . CheckValue ( env , nameof ( env ) ) ;
125
126
var h = env . Register ( "Categorical" ) ;
@@ -140,7 +141,7 @@ public static IDataView Create(IHostEnvironment env, Arguments args, IDataView i
140
141
col . SetTerms ( column . Terms ) ;
141
142
columns . Add ( col ) ;
142
143
}
143
- return Create ( env , input , columns . ToArray ( ) ) ;
144
+ return Create ( env , input , columns . ToArray ( ) ) as IDataTransform ;
144
145
}
145
146
146
147
public static IDataView Create ( IHostEnvironment env , IDataView input , params CategoricalEstimator . ColumnInfo [ ] columns )
@@ -306,4 +307,109 @@ public static CommonOutputs.TransformOutput KeyToText(IHostEnvironment env, KeyT
306
307
return new CommonOutputs . TransformOutput { Model = new TransformModel ( env , xf , input . Data ) , OutputData = xf } ;
307
308
}
308
309
}
310
+
311
+ public enum OneHotOutputKind : byte
312
+ {
313
+ /// <summary>
314
+ /// Output is a bag (multi-set) vector
315
+ /// </summary>
316
+ Bag = 1 ,
317
+
318
+ /// <summary>
319
+ /// Output is an indicator vector
320
+ /// </summary>
321
+ Ind = 2 ,
322
+
323
+ /// <summary>
324
+ /// Output is a key value
325
+ /// </summary>
326
+ Key = 3 ,
327
+
328
+ /// <summary>
329
+ /// Output is binary encoded
330
+ /// </summary>
331
+ Bin = 4 ,
332
+ }
333
+
334
+ public static partial class CategoricalStaticExtensions
335
+ {
336
+ // I am not certain I see a good way to cover the distinct types beyond complete enumeration.
337
+ // Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial
338
+ // class, and all the public facing extension methods for each possible type are in a T4 generated result.
339
+
340
+ private const KeyValueOrder DefSort = ( KeyValueOrder ) TermEstimator . Defaults . Sort ;
341
+ private const int DefMax = TermEstimator . Defaults . MaxNumTerms ;
342
+ private const OneHotOutputKind DefOut = ( OneHotOutputKind ) CategoricalEstimator . Defaults . OutKind ;
343
+
344
+ private struct Config
345
+ {
346
+ public readonly KeyValueOrder Order ;
347
+ public readonly int Max ;
348
+ public readonly OneHotOutputKind OutputKind ;
349
+
350
+ public Config ( KeyValueOrder order , int max , OneHotOutputKind outputKind )
351
+ {
352
+ Order = order ;
353
+ Max = max ;
354
+ OutputKind = outputKind ;
355
+ }
356
+ }
357
+
358
+ private interface IOneHotCol
359
+ {
360
+ PipelineColumn Input { get ; }
361
+ Config Config { get ; }
362
+ }
363
+
364
+ private sealed class ImplScalar < T > : Vector < float > , IOneHotCol
365
+ {
366
+ public PipelineColumn Input { get ; }
367
+ public Config Config { get ; }
368
+ public ImplScalar ( PipelineColumn input , Config config ) : base ( Rec . Inst , input )
369
+ {
370
+ Input = input ;
371
+ Config = config ;
372
+ }
373
+ }
374
+
375
+ private sealed class ImplVector < T > : Vector < float > , IOneHotCol
376
+ {
377
+ public PipelineColumn Input { get ; }
378
+ public Config Config { get ; }
379
+ public ImplVector ( PipelineColumn input , Config config ) : base ( Rec . Inst , input )
380
+ {
381
+ Input = input ;
382
+ Config = config ;
383
+ }
384
+ }
385
+
386
+ private sealed class ImplVarVector < T > : VarVector < float > , IOneHotCol
387
+ {
388
+ public PipelineColumn Input { get ; }
389
+ public Config Config { get ; }
390
+ public ImplVarVector ( PipelineColumn input , Config config ) : base ( Rec . Inst , input )
391
+ {
392
+ Input = input ;
393
+ Config = config ;
394
+ }
395
+ }
396
+
397
+ private sealed class Rec : EstimatorReconciler
398
+ {
399
+ public static readonly Rec Inst = new Rec ( ) ;
400
+
401
+ public override IEstimator < ITransformer > Reconcile ( IHostEnvironment env , PipelineColumn [ ] toOutput ,
402
+ IReadOnlyDictionary < PipelineColumn , string > inputNames , IReadOnlyDictionary < PipelineColumn , string > outputNames , IReadOnlyCollection < string > usedNames )
403
+ {
404
+ var infos = new CategoricalEstimator . ColumnInfo [ toOutput . Length ] ;
405
+ for ( int i = 0 ; i < toOutput . Length ; ++ i )
406
+ {
407
+ var tcol = ( IOneHotCol ) toOutput [ i ] ;
408
+ infos [ i ] = new CategoricalEstimator . ColumnInfo ( inputNames [ tcol . Input ] , outputNames [ toOutput [ i ] ] , ( CategoricalTransform . OutputKind ) tcol . Config . OutputKind ,
409
+ tcol . Config . Max , ( TermTransform . SortOrder ) tcol . Config . Order ) ;
410
+ }
411
+ return new CategoricalEstimator ( env , infos ) ;
412
+ }
413
+ }
414
+ }
309
415
}
0 commit comments