Skip to content

Commit f9d547b

Browse files
authored
Upgrade ML.NET package (dotnet#343)
1 parent 6a752a5 commit f9d547b

File tree

74 files changed

+506
-460
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

74 files changed

+506
-460
lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+4-5
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8-
using Microsoft.Data.DataView;
98
using Microsoft.ML.Data;
109

1110
namespace Microsoft.ML.Auto
@@ -37,10 +36,10 @@ public enum BinaryClassificationTrainer
3736
FastTree,
3837
LightGbm,
3938
LinearSupportVectorMachines,
40-
LogisticRegression,
41-
StochasticDualCoordinateAscent,
42-
StochasticGradientDescent,
43-
SymbolicStochasticGradientDescent,
39+
LbfgsLogisticRegression,
40+
SdcaLogisticRegression,
41+
SgdCalibrated,
42+
SymbolicSgdLogisticRegression,
4443
}
4544

4645
public sealed class BinaryClassificationExperiment

src/Microsoft.ML.Auto/API/ColumnInference.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ public sealed class ColumnInferenceResults
1717
public sealed class ColumnInformation
1818
{
1919
public string LabelColumn { get; set; } = DefaultColumnNames.Label;
20-
public string WeightColumn { get; set; }
20+
public string ExampleWeightColumn { get; set; }
2121
public string SamplingKeyColumn { get; set; }
2222
public ICollection<string> CategoricalColumns { get; } = new Collection<string>();
2323
public ICollection<string> NumericColumns { get; } = new Collection<string>();

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+14-15
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8-
using Microsoft.Data.DataView;
98
using Microsoft.ML.Data;
109

1110
namespace Microsoft.ML.Auto
@@ -15,7 +14,7 @@ public sealed class MulticlassExperimentSettings : ExperimentSettings
1514
public MulticlassClassificationMetric OptimizingMetric { get; set; } = MulticlassClassificationMetric.MicroAccuracy;
1615
public ICollection<MulticlassClassificationTrainer> Trainers { get; } =
1716
Enum.GetValues(typeof(MulticlassClassificationTrainer)).OfType<MulticlassClassificationTrainer>().ToList();
18-
public IProgress<RunResult<MultiClassClassifierMetrics>> ProgressHandler { get; set; }
17+
public IProgress<RunResult<MulticlassClassificationMetrics>> ProgressHandler { get; set; }
1918
}
2019

2120
public enum MulticlassClassificationMetric
@@ -34,11 +33,11 @@ public enum MulticlassClassificationTrainer
3433
FastTreeOVA,
3534
LightGbm,
3635
LinearSupportVectorMachinesOVA,
37-
LogisticRegression,
38-
LogisticRegressionOVA,
39-
StochasticDualCoordinateAscent,
40-
StochasticGradientDescentOVA,
41-
SymbolicStochasticGradientDescentOVA,
36+
LbfgsMaximumEntropy,
37+
LbfgsLogisticRegressionOVA,
38+
SdcaMaximumEntropy,
39+
SgdCalibratedOVA,
40+
SymbolicSgdLogisticRegressionOVA,
4241
}
4342

4443
public sealed class MulticlassClassificationExperiment
@@ -52,7 +51,7 @@ internal MulticlassClassificationExperiment(MLContext context, MulticlassExperim
5251
_settings = settings;
5352
}
5453

55-
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
54+
public IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
5655
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
5756
{
5857
var columnInformation = new ColumnInformation()
@@ -63,28 +62,28 @@ public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView tra
6362
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
6463
}
6564

66-
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
65+
public IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
6766
{
6867
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
6968
}
7069

71-
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
70+
public IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
7271
{
7372
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
7473
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
7574
}
7675

77-
public IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
76+
public IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
7877
{
7978
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
8079
}
8180

82-
internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null)
81+
internal IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null)
8382
{
8483
throw new NotImplementedException();
8584
}
8685

87-
internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext context,
86+
internal IEnumerable<RunResult<MulticlassClassificationMetrics>> Execute(MLContext context,
8887
IDataView trainData,
8988
ColumnInformation columnInfo,
9089
IDataView validationData = null,
@@ -94,7 +93,7 @@ internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext c
9493
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
9594

9695
// run autofit & get all pipelines run in that process
97-
var experiment = new Experiment<MultiClassClassifierMetrics>(context, TaskKind.MulticlassClassification, trainData,
96+
var experiment = new Experiment<MulticlassClassificationMetrics>(context, TaskKind.MulticlassClassification, trainData,
9897
columnInfo, validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
9998
_settings.ProgressHandler, _settings, new MultiMetricsAgent(_settings.OptimizingMetric),
10099
TrainerExtensionUtil.GetTrainerNames(_settings.Trainers));
@@ -105,7 +104,7 @@ internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext c
105104

106105
public static class MulticlassExperimentResultExtensions
107106
{
108-
public static RunResult<MultiClassClassifierMetrics> Best(this IEnumerable<RunResult<MultiClassClassifierMetrics>> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy)
107+
public static RunResult<MulticlassClassificationMetrics> Best(this IEnumerable<RunResult<MulticlassClassificationMetrics>> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy)
109108
{
110109
var metricsAgent = new MultiMetricsAgent(metric);
111110
return RunResultUtil.GetBestRunResult(results, metricsAgent);

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+2-3
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8-
using Microsoft.Data.DataView;
98
using Microsoft.ML.Data;
109

1110
namespace Microsoft.ML.Auto
@@ -33,8 +32,8 @@ public enum RegressionTrainer
3332
FastTreeTweedie,
3433
LightGbm,
3534
OnlineGradientDescent,
36-
OrdinaryLeastSquares,
37-
PoissonRegression,
35+
Ols,
36+
LbfgsPoissonRegression,
3837
StochasticDualCoordinateAscent,
3938
}
4039

src/Microsoft.ML.Auto/AutoMlUtils.cs

+2-2
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
using System;
66
using System.Threading;
7-
using Microsoft.Data.DataView;
7+
using Microsoft.ML.Data;
88

99
namespace Microsoft.ML.Auto
1010
{
@@ -30,7 +30,7 @@ public static (IDataView testData, IDataView validationData) TestValidateSplit(t
3030
MLContext context, IDataView trainData, ColumnInformation columnInfo)
3131
{
3232
IDataView validationData;
33-
var splitData = catalog.TrainTestSplit(trainData, samplingKeyColumn: columnInfo.SamplingKeyColumn);
33+
var splitData = context.Data.TrainTestSplit(trainData, samplingKeyColumnName: columnInfo.SamplingKeyColumn);
3434
trainData = splitData.TrainSet;
3535
validationData = splitData.TestSet;
3636
trainData = trainData.DropLastColumn(context);

src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs

+3-4
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,9 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
76
using System.Linq;
8-
using Microsoft.Data.DataView;
7+
using Microsoft.ML.Data;
98

109
namespace Microsoft.ML.Auto
1110
{
@@ -18,7 +17,7 @@ internal static class ColumnInformationUtil
1817
return ColumnPurpose.Label;
1918
}
2019

21-
if (columnName == columnInfo.WeightColumn)
20+
if (columnName == columnInfo.ExampleWeightColumn)
2221
{
2322
return ColumnPurpose.Weight;
2423
}
@@ -63,7 +62,7 @@ internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, Colu
6362
columnInfo.LabelColumn = column.name;
6463
break;
6564
case ColumnPurpose.Weight:
66-
columnInfo.WeightColumn = column.name;
65+
columnInfo.ExampleWeightColumn = column.name;
6766
break;
6867
case ColumnPurpose.SamplingKey:
6968
columnInfo.SamplingKeyColumn = column.name;

src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs

+4-6
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
using System.Collections.Generic;
77
using System.Linq;
88
using System.Text.RegularExpressions;
9-
using Microsoft.Data.DataView;
109
using Microsoft.ML.Data;
1110

1211
namespace Microsoft.ML.Auto
@@ -272,16 +271,15 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext context, I
272271
var data = new List<ReadOnlyMemory<char>[]>();
273272
using (var cursor = idv.GetRowCursor(idv.Schema))
274273
{
275-
var column = cursor.Schema.GetColumnOrNull("C");
276-
int columnIndex = column.Value.Index;
277-
var colType = column.Value.Type;
274+
var column = cursor.Schema.GetColumnOrNull("C").Value;
275+
var colType = column.Type;
278276
ValueGetter<VBuffer<ReadOnlyMemory<char>>> vecGetter = null;
279277
ValueGetter<ReadOnlyMemory<char>> oneGetter = null;
280278
bool isVector = colType.IsVector();
281-
if (isVector) { vecGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(columnIndex); }
279+
if (isVector) { vecGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column); }
282280
else
283281
{
284-
oneGetter = cursor.GetGetter<ReadOnlyMemory<char>>(columnIndex);
282+
oneGetter = cursor.GetGetter<ReadOnlyMemory<char>>(column);
285283
}
286284

287285
VBuffer<ReadOnlyMemory<char>> line = default;

src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs

+4-5
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8-
using System.Text.RegularExpressions;
9-
using Microsoft.Data.DataView;
108
using Microsoft.ML.Data;
119

1210
namespace Microsoft.ML.Auto
@@ -88,10 +86,11 @@ public IReadOnlyList<ReadOnlyMemory<char>> GetColumnData()
8886
return _cachedData;
8987

9088
var results = new List<ReadOnlyMemory<char>>();
91-
92-
using (var cursor = _data.GetRowCursor(new[] { _data.Schema[_columnId] }))
89+
var column = _data.Schema[_columnId];
90+
91+
using (var cursor = _data.GetRowCursor(new[] { column }))
9392
{
94-
var getter = cursor.GetGetter<ReadOnlyMemory<char>>(_columnId);
93+
var getter = cursor.GetGetter<ReadOnlyMemory<char>>(column);
9594
while (cursor.MoveNext())
9695
{
9796
var value = default(ReadOnlyMemory<char>);

src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs

+2-3
Original file line number | Diff line number | Diff line change
@@ -89,11 +89,10 @@ private static bool TryParseFile(MLContext context, TextLoader.Options options,
8989
var idv = context.Data.TakeRows(textLoader.Load(source), 1000);
9090
var columnCounts = new List<int>();
9191
var column = idv.Schema["C"];
92-
var columnIndex = column.Index;
9392

94-
using (var cursor = idv.GetRowCursor(new[] { idv.Schema[columnIndex] }))
93+
using (var cursor = idv.GetRowCursor(new[] { column }))
9594
{
96-
var getter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(columnIndex);
95+
var getter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column);
9796

9897
VBuffer<ReadOnlyMemory<char>> line = default;
9998
while (cursor.MoveNext())

src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs

+4-4
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Microsoft.Data.DataView;
5+
using Microsoft.ML.Data;
66

77
namespace Microsoft.ML.Auto
88
{
@@ -30,15 +30,15 @@ public static ColumnDimensions[] CalcColumnDimensions(MLContext context, IDataVi
3030
// If categorical text feature, calculate cardinality
3131
if (itemType.IsText() && purpose.Purpose == ColumnPurpose.CategoricalFeature)
3232
{
33-
cardinality = DatasetDimensionsUtil.GetTextColumnCardinality(data, i);
33+
cardinality = DatasetDimensionsUtil.GetTextColumnCardinality(data, column);
3434
}
3535

3636
// If numeric feature, discover missing values
3737
if (itemType == NumberDataViewType.Single)
3838
{
3939
hasMissing = column.Type.IsVector() ?
40-
DatasetDimensionsUtil.HasMissingNumericVector(data, i) :
41-
DatasetDimensionsUtil.HasMissingNumericSingleValue(data, i);
40+
DatasetDimensionsUtil.HasMissingNumericVector(data, column) :
41+
DatasetDimensionsUtil.HasMissingNumericSingleValue(data, column);
4242
}
4343

4444
colDimensions[i] = new ColumnDimensions(cardinality, hasMissing);

src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs

+9-10
Original file line number | Diff line number | Diff line change
@@ -4,19 +4,18 @@
44

55
using System;
66
using System.Collections.Generic;
7-
using Microsoft.Data.DataView;
87
using Microsoft.ML.Data;
98

109
namespace Microsoft.ML.Auto
1110
{
1211
internal static class DatasetDimensionsUtil
1312
{
14-
public static int GetTextColumnCardinality(IDataView data, int colIndex)
13+
public static int GetTextColumnCardinality(IDataView data, DataViewSchema.Column column)
1514
{
1615
var seen = new HashSet<string>();
17-
using (var cursor = data.GetRowCursor(new[] { data.Schema[colIndex] }))
16+
using (var cursor = data.GetRowCursor(new[] { column }))
1817
{
19-
var getter = cursor.GetGetter<ReadOnlyMemory<char>>(colIndex);
18+
var getter = cursor.GetGetter<ReadOnlyMemory<char>>(column);
2019
while (cursor.MoveNext())
2120
{
2221
var value = default(ReadOnlyMemory<char>);
@@ -28,11 +27,11 @@ public static int GetTextColumnCardinality(IDataView data, int colIndex)
2827
return seen.Count;
2928
}
3029

31-
public static bool HasMissingNumericSingleValue(IDataView data, int colIndex)
30+
public static bool HasMissingNumericSingleValue(IDataView data, DataViewSchema.Column column)
3231
{
33-
using (var cursor = data.GetRowCursor(new[] { data.Schema[colIndex] }))
32+
using (var cursor = data.GetRowCursor(new[] { column }))
3433
{
35-
var getter = cursor.GetGetter<Single>(colIndex);
34+
var getter = cursor.GetGetter<Single>(column);
3635
var value = default(Single);
3736
while (cursor.MoveNext())
3837
{
@@ -46,11 +45,11 @@ public static bool HasMissingNumericSingleValue(IDataView data, int colIndex)
4645
}
4746
}
4847

49-
public static bool HasMissingNumericVector(IDataView data, int colIndex)
48+
public static bool HasMissingNumericVector(IDataView data, DataViewSchema.Column column)
5049
{
51-
using (var cursor = data.GetRowCursor(new[] { data.Schema[colIndex] }))
50+
using (var cursor = data.GetRowCursor(new[] { column }))
5251
{
53-
var getter = cursor.GetGetter<VBuffer<Single>>(colIndex);
52+
var getter = cursor.GetGetter<VBuffer<Single>>(column);
5453
var value = default(VBuffer<Single>);
5554
while (cursor.MoveNext())
5655
{

0 commit comments

Comments
 (0)