2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using System ;
6
- using System . Linq ;
5
+ using Microsoft . ML . Core . Data ;
7
6
using Microsoft . ML . Runtime ;
8
7
using Microsoft . ML . Runtime . Data ;
9
8
using Microsoft . ML . Runtime . EntryPoints ;
12
11
using Microsoft . ML . Runtime . Model ;
13
12
using Microsoft . ML . Runtime . Training ;
14
13
using Microsoft . ML . Runtime . Internal . Internallearn ;
14
+ using System ;
15
+ using System . Collections . Generic ;
16
+ using System . Linq ;
15
17
16
18
[ assembly: LoadableClass ( MultiClassNaiveBayesTrainer . Summary , typeof ( MultiClassNaiveBayesTrainer ) , typeof ( MultiClassNaiveBayesTrainer . Arguments ) ,
17
19
new [ ] { typeof ( SignatureMultiClassClassifierTrainer ) , typeof ( SignatureTrainer ) } ,
22
24
[ assembly: LoadableClass ( typeof ( MultiClassNaiveBayesPredictor ) , null , typeof ( SignatureLoadModel ) ,
23
25
"Multi Class Naive Bayes predictor" , MultiClassNaiveBayesPredictor . LoaderSignature ) ]
24
26
25
- [ assembly: LoadableClass ( typeof ( void ) , typeof ( MultiClassNaiveBayesTrainer ) , null , typeof ( SignatureEntryPointModule ) , "MultiClassNaiveBayes" ) ]
27
+ [ assembly: LoadableClass ( typeof ( void ) , typeof ( MultiClassNaiveBayesTrainer ) , null , typeof ( SignatureEntryPointModule ) , MultiClassNaiveBayesTrainer . LoadName ) ]
26
28
27
29
namespace Microsoft . ML . Runtime . Learners
28
30
{
29
31
/// <include file='doc.xml' path='doc/members/member[@name="MultiClassNaiveBayesTrainer"]' />
30
- public sealed class MultiClassNaiveBayesTrainer : TrainerBase < MultiClassNaiveBayesPredictor >
32
+ public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase < MulticlassPredictionTransformer < MultiClassNaiveBayesPredictor > , MultiClassNaiveBayesPredictor >
31
33
{
32
34
public const string LoadName = "MultiClassNaiveBayes" ;
33
35
internal const string UserName = "Multiclass Naive Bayes" ;
@@ -43,13 +45,52 @@ public sealed class Arguments : LearnerInputBaseWithLabel
43
45
private static readonly TrainerInfo _info = new TrainerInfo ( normalization : false , caching : false ) ;
44
46
public override TrainerInfo Info => _info ;
45
47
46
- public MultiClassNaiveBayesTrainer ( IHostEnvironment env , Arguments args )
47
- : base ( env , LoadName )
48
+ /// <summary>
49
+ /// Initializes a new instance of <see cref="MultiClassNaiveBayesTrainer"/>
50
+ /// </summary>
51
+ /// <param name="env">The environment to use.</param>
52
+ /// <param name="labelColumn">The name of the label column.</param>
53
+ /// <param name="featureColumn">The name of the feature column.</param>
54
+ public MultiClassNaiveBayesTrainer ( IHostEnvironment env , string featureColumn , string labelColumn )
55
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadName ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) ,
56
+ TrainerUtils . MakeU4ScalarLabel ( labelColumn ) )
57
+ {
58
+ Host . CheckNonEmpty ( featureColumn , nameof ( featureColumn ) ) ;
59
+ Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
60
+ }
61
+
62
+ /// <summary>
63
+ /// Initializes a new instance of <see cref="MultiClassNaiveBayesTrainer"/>
64
+ /// </summary>
65
+ internal MultiClassNaiveBayesTrainer ( IHostEnvironment env , Arguments args )
66
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
67
+ TrainerUtils . MakeU4ScalarLabel ( args . LabelColumn ) )
48
68
{
49
69
Host . CheckValue ( args , nameof ( args ) ) ;
50
70
}
51
71
52
- public override MultiClassNaiveBayesPredictor Train ( TrainContext context )
72
+ protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
73
+ {
74
+ bool success = inputSchema . TryFindColumn ( LabelColumn . Name , out var labelCol ) ;
75
+ Contracts . Assert ( success ) ;
76
+
77
+ var scoreMetadata = new List < SchemaShape . Column > ( ) { new SchemaShape . Column ( MetadataUtils . Kinds . SlotNames , SchemaShape . Column . VectorKind . Vector , TextType . Instance , false ) } ;
78
+ scoreMetadata . AddRange ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ;
79
+
80
+ var predLabelMetadata = new SchemaShape ( labelCol . Metadata . Columns . Where ( x => x . Name == MetadataUtils . Kinds . KeyValues )
81
+ . Concat ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ;
82
+
83
+ return new [ ]
84
+ {
85
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Vector , NumberType . R4 , false , new SchemaShape ( scoreMetadata ) ) ,
86
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true , predLabelMetadata )
87
+ } ;
88
+ }
89
+
90
+ protected override MulticlassPredictionTransformer < MultiClassNaiveBayesPredictor > MakeTransformer ( MultiClassNaiveBayesPredictor model , ISchema trainSchema )
91
+ => new MulticlassPredictionTransformer < MultiClassNaiveBayesPredictor > ( Host , model , trainSchema , FeatureColumn . Name , LabelColumn . Name ) ;
92
+
93
+ protected override MultiClassNaiveBayesPredictor TrainModelCore ( TrainContext context )
53
94
{
54
95
Host . CheckValue ( context , nameof ( context ) ) ;
55
96
var data = context . TrainingSet ;
@@ -170,6 +211,40 @@ private static VersionInfo GetVersionInfo()
170
211
171
212
public ColumnType OutputType => _outputType ;
172
213
214
+ /// <summary>
215
+ /// Copies the label histogram into a buffer.
216
+ /// </summary>
217
+ /// <param name="labelHistogram">A possibly reusable array, which will
218
+ /// be expanded as necessary to accomodate the data.</param>
219
+ /// <param name="labelCount">Set to the length of the resized array, which is also the number of different labels.</param>
220
+ public void GetLabelHistogram ( ref int [ ] labelHistogram , out int labelCount )
221
+ {
222
+ labelCount = _labelCount ;
223
+ Utils . EnsureSize ( ref labelHistogram , _labelCount ) ;
224
+ Array . Copy ( _labelHistogram , labelHistogram , _labelCount ) ;
225
+ }
226
+
227
+ /// <summary>
228
+ /// Copies the feature histogram into a buffer.
229
+ /// </summary>
230
+ /// <param name="featureHistogram">A possibly reusable array, which will
231
+ /// be expanded as necessary to accomodate the data.</param>
232
+ /// <param name="labelCount">Set to the first dimension of the resized array,
233
+ /// which is the number of different labels encountered in training.</param>
234
+ /// <param name="featureCount">Set to the second dimension of the resized array,
235
+ /// which is also the number of different feature combinations encountered in training.</param>
236
+ public void GetFeatureHistogram ( ref int [ ] [ ] featureHistogram , out int labelCount , out int featureCount )
237
+ {
238
+ labelCount = _labelCount ;
239
+ featureCount = _featureCount ;
240
+ Utils . EnsureSize ( ref featureHistogram , _labelCount ) ;
241
+ for ( int i = 0 ; i < _labelCount ; i ++ )
242
+ {
243
+ Utils . EnsureSize ( ref featureHistogram [ i ] , _featureCount ) ;
244
+ Array . Copy ( _featureHistogram [ i ] , featureHistogram [ i ] , _featureCount ) ;
245
+ }
246
+ }
247
+
173
248
internal MultiClassNaiveBayesPredictor ( IHostEnvironment env , int [ ] labelHistogram , int [ ] [ ] featureHistogram , int featureCount )
174
249
: base ( env , LoaderSignature )
175
250
{
0 commit comments