7
7
using Microsoft . ML . Runtime . CommandLine ;
8
8
using Microsoft . ML . Runtime . Data ;
9
9
using Microsoft . ML . Runtime . Data . Conversion ;
10
+ using Microsoft . ML . Runtime . EntryPoints ;
10
11
using Microsoft . ML . Runtime . Internal . Calibration ;
11
12
using Microsoft . ML . Runtime . Internal . Internallearn ;
12
13
using Microsoft . ML . Runtime . Internal . Utilities ;
@@ -45,7 +46,7 @@ internal static class FastTreeShared
45
46
}
46
47
47
48
public abstract class FastTreeTrainerBase < TArgs , TTransformer , TModel > :
48
- TrainerEstimatorBase < TTransformer , TModel >
49
+ TrainerEstimatorBaseWithGroupId < TTransformer , TModel >
49
50
where TTransformer : ISingleFeaturePredictionTransformer < TModel >
50
51
where TArgs : TreeArgs , new ( )
51
52
where TModel : IPredictorProducing < Float >
@@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
92
93
/// <summary>
93
94
/// Constructor to use when instantiating the classes deriving from here through the API.
94
95
/// </summary>
95
- private protected FastTreeTrainerBase ( IHostEnvironment env , SchemaShape . Column label , string featureColumn ,
96
- string weightColumn = null , string groupIdColumn = null , Action < TArgs > advancedSettings = null )
97
- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) )
96
+ private protected FastTreeTrainerBase ( IHostEnvironment env ,
97
+ SchemaShape . Column label ,
98
+ string featureColumn ,
99
+ string weightColumn ,
100
+ string groupIdColumn ,
101
+ int numLeaves ,
102
+ int numTrees ,
103
+ int minDocumentsInLeafs ,
104
+ Action < TArgs > advancedSettings )
105
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) , TrainerUtils . MakeU4ScalarColumn ( groupIdColumn ) )
98
106
{
99
107
Args = new TArgs ( ) ;
100
108
109
+ // Set up the directly provided values:
110
+ // they override the defaults of the freshly created Args.
111
+ Args . NumLeaves = numLeaves ;
112
+ Args . NumTrees = numTrees ;
113
+ Args . MinDocumentsInLeafs = minDocumentsInLeafs ;
114
+
101
115
// Apply the advanced settings, if the user supplied any.
102
116
advancedSettings ? . Invoke ( Args ) ;
103
117
104
- // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
105
- TrainerUtils . CheckArgsHaveDefaultColNames ( Host , Args ) ;
106
-
107
118
Args . LabelColumn = label . Name ;
108
119
Args . FeatureColumn = featureColumn ;
109
120
110
121
if ( weightColumn != null )
111
- Args . WeightColumn = weightColumn ;
122
+ Args . WeightColumn = Optional < string > . Explicit ( weightColumn ) ; ;
112
123
113
124
if ( groupIdColumn != null )
114
- Args . GroupIdColumn = groupIdColumn ;
125
+ Args . GroupIdColumn = Optional < string > . Explicit ( groupIdColumn ) ; ;
115
126
116
127
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
117
128
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
@@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
128
139
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
129
140
/// </summary>
130
141
private protected FastTreeTrainerBase ( IHostEnvironment env , TArgs args , SchemaShape . Column label )
131
- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) )
142
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn , args . WeightColumn . IsExplicit ) )
132
143
{
133
144
Host . CheckValue ( args , nameof ( args ) ) ;
134
145
Args = args ;
@@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
159
170
return Float . PositiveInfinity ;
160
171
}
161
172
162
- /// <summary>
163
- /// If, after applying the advancedSettings delegate, the args are different than the default value
164
- /// and are also different than the value supplied directly to the extension method, warn the user
165
- /// about which value is being used.
166
- /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
167
- /// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
168
- /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
169
- /// </summary>
170
- protected void CheckArgsAndAdvancedSettingMismatch ( int numLeaves ,
171
- int numTrees ,
172
- int minDocumentsInLeafs ,
173
- double learningRate ,
174
- BoostedTreeArgs snapshot ,
175
- BoostedTreeArgs currentArgs )
176
- {
177
- using ( var ch = Host . Start ( "Comparing advanced settings with the directly provided values." ) )
178
- {
179
-
180
- // Check that the user didn't supply different parameters in the args, from what it specified directly.
181
- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , numLeaves , snapshot . NumLeaves , currentArgs . NumLeaves , nameof ( numLeaves ) ) ;
182
- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , numTrees , snapshot . NumTrees , currentArgs . NumTrees , nameof ( numTrees ) ) ;
183
- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , minDocumentsInLeafs , snapshot . MinDocumentsInLeafs , currentArgs . MinDocumentsInLeafs , nameof ( minDocumentsInLeafs ) ) ;
184
- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , learningRate , snapshot . LearningRates , currentArgs . LearningRates , nameof ( learningRate ) ) ;
185
- }
186
- }
187
-
188
173
private void Initialize ( IHostEnvironment env )
189
174
{
190
175
int numThreads = Args . NumThreads ?? Environment . ProcessorCount ;
0 commit comments