Skip to content

Commit 361caa2

Browse files
srsaggam authored and Dmitry-A committed
Caching enabling in code gen part -2 (dotnet#298)
* add * added caching codegen
1 parent 3d9bf7c commit 361caa2

File tree

5 files changed

+20
-10
lines changed

5 files changed

+20
-10
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeTest.approved.txt

+3-2
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,8 @@ namespace MyNamespace
6868
allowSparse: true);
6969

7070
// Common data process configuration with pipeline data transformations
71-
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" });
71+
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
72+
.AppendCacheCheckpoint(mlContext);
7273

7374
// Set the training algorithm, then create and config the modelBuilder
7475
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
@@ -100,7 +101,7 @@ namespace MyNamespace
100101
{
101102
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
102103
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
103-
path: TrainDataPath,
104+
path: TestDataPath,
104105
hasHeader: true,
105106
separatorChar: ',',
106107
allowQuoting: true,

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ public void GeneratedHelperCodeTest()
104104
var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), new ColumnInformation(), hyperparams2);
105105
var transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
106106
var transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
107-
var inferredPipeline1 = new SuggestedPipeline(transforms1, new List<SuggestedTransform>(), trainer1, context, false);
107+
var inferredPipeline1 = new SuggestedPipeline(transforms1, new List<SuggestedTransform>(), trainer1, context, true);
108108
var inferredPipeline2 = new SuggestedPipeline(transforms2, new List<SuggestedTransform>(), trainer2, context, false);
109109

110110
this.pipeline = inferredPipeline1.ToPipeline();

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

+4-3
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ public void GenerateOutput()
5454
var namespaceValue = Utils.Normalize(settings.OutputName);
5555

5656
// Generate code for training and scoring
57-
var trainFileContent = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);
57+
var trainFileContent = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue, pipeline.CacheBeforeTrainer);
5858
var tree = CSharpSyntaxTree.ParseText(trainFileContent);
5959
var syntaxNode = tree.GetRoot();
6060
trainFileContent = Formatter.Format(syntaxNode, new AdhocWorkspace()).ToFullString();
@@ -91,7 +91,7 @@ internal static string GeneratProjectCode()
9191
return projectCodeGen.TransformText();
9292
}
9393

94-
internal string GenerateTrainCode(string usings, string trainer, List<string> transforms, IList<string> columns, IList<string> classLabels, string namespaceValue)
94+
internal string GenerateTrainCode(string usings, string trainer, List<string> transforms, IList<string> columns, IList<string> classLabels, string namespaceValue, bool cacheBeforeTrainer)
9595
{
9696
var trainingAndScoringCodeGen = new MLCodeGen()
9797
{
@@ -110,7 +110,8 @@ internal string GenerateTrainCode(string usings, string trainer, List<string> tr
110110
TaskType = settings.MlTask.ToString(),
111111
Namespace = namespaceValue,
112112
LabelName = settings.LabelName,
113-
ModelPath = settings.ModelPath
113+
ModelPath = settings.ModelPath,
114+
CacheBeforeTrainer = cacheBeforeTrainer
114115
};
115116

116117
return trainingAndScoringCodeGen.TransformText();

src/mlnet/Templates/Console/MLCodeGen.cs

+9-2
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext)
126126
{ Write(")");
127127
}
128128
}
129+
if(CacheBeforeTrainer){ Write("\r\n .AppendCacheCheckpoint(mlContext)");}
129130
this.Write(";\r\n");
130131
}
131132
this.Write("\r\n // Set the training algorithm, then create and config the modelBuil" +
@@ -207,8 +208,13 @@ private static void Predict(MLContext mlContext)
207208
{
208209
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
209210
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
210-
path: TrainDataPath,
211-
hasHeader : ");
211+
path: ");
212+
if(!string.IsNullOrEmpty(TestPath)){
213+
this.Write("TestDataPath");
214+
}else{
215+
this.Write("TrainDataPath");
216+
}
217+
this.Write(",\r\n hasHeader : ");
212218
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
213219
this.Write(",\r\n separatorChar : \'");
214220
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
@@ -284,6 +290,7 @@ private static void Predict(MLContext mlContext)
284290
public string Namespace {get;set;}
285291
public string LabelName {get;set;}
286292
public string ModelPath {get;set;}
293+
public bool CacheBeforeTrainer {get;set;}
287294

288295
}
289296
#region Base class

src/mlnet/Templates/Console/MLCodeGen.tt

+3-2
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ namespace <#= Namespace #>
8888
if(i>0)
8989
{ Write(")");
9090
}
91-
}#>;
91+
}#><#if(CacheBeforeTrainer){ Write("\r\n .AppendCacheCheckpoint(mlContext)");} #>;
9292
<#}#>
9393

9494
// Set the training algorithm, then create and config the modelBuilder
@@ -146,7 +146,7 @@ if(string.IsNullOrEmpty(TestPath)){ #>
146146
{
147147
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
148148
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
149-
path: TrainDataPath,
149+
path: <#if(!string.IsNullOrEmpty(TestPath)){ #>TestDataPath<#}else{#>TrainDataPath<#}#>,
150150
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
151151
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
152152
allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>,
@@ -219,4 +219,5 @@ public int Kfolds {get;set;} = 5;
219219
public string Namespace {get;set;}
220220
public string LabelName {get;set;}
221221
public string ModelPath {get;set;}
222+
public bool CacheBeforeTrainer {get;set;}
222223
#>

0 commit comments

Comments
 (0)