Skip to content

Commit 6716d1e

Browse files
srsaggamDmitry-A
authored andcommitted
Default the kfolds to value 5 in CLI generated code (dotnet#115)
* Added sequential grouping of columns * reverted the file * Set up CI with Azure Pipelines * Update azure-pipelines.yml for Azure Pipelines * Update azure-pipelines.yml for Azure Pipelines * remove file * added kfold param and defaulted to value * changed type * added for regression
1 parent e4eed98 commit 6716d1e

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/mlnet/Templates/MLCodeGen.cs

+9-6
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,18 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
133133
if("BinaryClassification".Equals(TaskType)){
134134
this.Write(" var crossValidationResults = mlContext.");
135135
this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
136-
this.Write(".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: 3, labe" +
137-
"lColumn:\"Label\");\r\n ConsoleHelper.PrintBinaryClassificationFoldsAvera" +
138-
"geMetrics(trainer.ToString(), crossValidationResults);\r\n");
136+
this.Write(".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: ");
137+
this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
138+
this.Write(", labelColumn:\"Label\");\r\n ConsoleHelper.PrintBinaryClassificationFolds" +
139+
"AverageMetrics(trainer.ToString(), crossValidationResults);\r\n");
139140
}
140141
if("Regression".Equals(TaskType)){
141142
this.Write(" var crossValidationResults = mlContext.");
142143
this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
143-
this.Write(".CrossValidate(trainingDataView, trainingPipeline, numFolds: 3, labelColumn:\"Labe" +
144-
"l\");\r\n ConsoleHelper.PrintRegressionFoldsAverageMetrics(trainer.ToStr" +
145-
"ing(), crossValidationResults);\r\n");
144+
this.Write(".CrossValidate(trainingDataView, trainingPipeline, numFolds: ");
145+
this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
146+
this.Write(", labelColumn:\"Label\");\r\n ConsoleHelper.PrintRegressionFoldsAverageMet" +
147+
"rics(trainer.ToString(), crossValidationResults);\r\n");
146148
}
147149
}
148150
this.Write(@"
@@ -249,6 +251,7 @@ private static void TestSinglePrediction(MLContext mlContext)
249251
public bool AllowQuoting {get;set;}
250252
public bool AllowSparse {get;set;}
251253
public bool TrimWhiteSpace {get;set;}
254+
public int Kfolds {get;set;} = 5;
252255

253256
}
254257
#region Base class

src/mlnet/Templates/MLCodeGen.tt

+3-2
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ else{#>
100100
// in order to evaluate and get the model's accuracy metrics
101101
Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
102102
<#if("BinaryClassification".Equals(TaskType)){ #>
103-
var crossValidationResults = mlContext.<#= TaskType #>.CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: 3, labelColumn:"Label");
103+
var crossValidationResults = mlContext.<#= TaskType #>.CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"Label");
104104
ConsoleHelper.PrintBinaryClassificationFoldsAverageMetrics(trainer.ToString(), crossValidationResults);
105105
<#}#><#if("Regression".Equals(TaskType)){ #>
106-
var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numFolds: 3, labelColumn:"Label");
106+
var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"Label");
107107
ConsoleHelper.PrintRegressionFoldsAverageMetrics(trainer.ToString(), crossValidationResults);
108108
<#}#>
109109
<# } #>
@@ -205,4 +205,5 @@ public string GeneratedUsings {get;set;}
205205
public bool AllowQuoting {get;set;}
206206
public bool AllowSparse {get;set;}
207207
public bool TrimWhiteSpace {get;set;}
208+
public int Kfolds {get;set;} = 5;
208209
#>

0 commit comments

Comments
 (0)