diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs index 4bff0606be..a3b5691322 100644 --- a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs @@ -125,12 +125,21 @@ internal static DatabaseLoader CreateDatabaseLoader(IHostEnvironment hos var column = new Column(); column.Name = mappingAttrName?.Name ?? memberInfo.Name; - var mappingAttr = memberInfo.GetCustomAttribute(); + var indexMappingAttr = memberInfo.GetCustomAttribute(); + var nameMappingAttr = memberInfo.GetCustomAttribute(); - if (mappingAttr is object) + if (indexMappingAttr is object) { - var sources = mappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray(); - column.Source = sources; + if (nameMappingAttr is object) + { + throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}"); + } + + column.Source = indexMappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray(); + } + else if (nameMappingAttr is object) + { + column.Source = nameMappingAttr.Sources.Select((source) => new Range(source)).ToArray(); } InternalDataKind dk; @@ -228,7 +237,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu public DbType Type = DbType.Single; /// - /// Source index range(s) of the column. + /// Source index or name range(s) of the column. /// [Argument(ArgumentType.Multiple, HelpText = "Source index range(s) of the column", ShortName = "src")] public Range[] Source; @@ -241,7 +250,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu } /// - /// Specifies the range of indices of input columns that should be mapped to an output column. + /// Specifies the range of indices or names of input columns that should be mapped to an output column. /// public sealed class Range { @@ -256,6 +265,19 @@ public Range(int index) Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative"); Min = index; Max = index; + Name = null; + } + + /// + /// A range representing a single value. Will result in a scalar column. + /// + /// The name of the field of the table to read. + public Range(string name) + { + Contracts.CheckValue(name, nameof(name)); + Min = -1; + Max = -1; + Name = name; } /// @@ -278,15 +300,30 @@ public Range(int min, int max) /// /// The minimum index of the column, inclusive. /// + /// + /// This value is ignored if is not null. + /// [Argument(ArgumentType.Required, HelpText = "First index in the range")] public int Min; /// /// The maximum index of the column, inclusive. /// + /// + /// This value is ignored if is not null. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")] public int Max; + /// + /// The name of the input column. + /// + /// + /// This value, if non-null, overrides and . + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")] + public string Name; + /// /// Force scalar columns to be treated as vectors of length one. /// @@ -318,6 +355,7 @@ public sealed class Options /// internal readonly struct Segment { + public readonly string Name; public readonly int Min; public readonly int Lim; public readonly bool ForceVector; @@ -325,10 +363,20 @@ internal readonly struct Segment public Segment(int min, int lim, bool forceVector) { Contracts.Assert(0 <= min & min < lim); + Name = null; Min = min; Lim = lim; ForceVector = forceVector; } + + public Segment(string name, bool forceVector) + { + Contracts.Assert(name != null); + Name = name; + Min = -1; + Lim = -1; + ForceVector = forceVector; + } } /// @@ -368,19 +416,23 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen if (segs != null) { var order = Utils.GetIdentityPermutation(segs.Length); - Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min)); - // Check that the segments are disjoint. - for (int i = 1; i < order.Length; i++) + if ((segs.Length != 0) && (segs[0].Name is null)) { - int a = order[i - 1]; - int b = order[i]; - Contracts.Assert(segs[a].Min <= segs[b].Min); - if (segs[a].Lim > segs[b].Min) + Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min)); + + // Check that the segments are disjoint. + for (int i = 1; i < order.Length; i++) { - throw user ? - Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) : - Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name); + int a = order[i - 1]; + int b = order[i]; + Contracts.Assert(segs[a].Min <= segs[b].Min); + if (segs[a].Lim > segs[b].Min) + { + throw user ? + Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) : + Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name); + } } } @@ -389,7 +441,7 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen for (int i = 0; i < segs.Length; i++) { var seg = segs[i]; - size += seg.Lim - seg.Min; + size += (seg.Name is null) ? seg.Lim - seg.Min : 1; } Contracts.Assert(size >= segs.Length); @@ -454,15 +506,23 @@ public Bindings(DatabaseLoader parent, Column[] cols) for (int i = 0; i < segs.Length; i++) { var range = col.Source[i]; - - int min = range.Min; - ch.CheckUserArg(0 <= min, nameof(range.Min)); - Segment seg; - int max = range.Max; - ch.CheckUserArg(min <= max, nameof(range.Max)); - seg = new Segment(min, max + 1, range.ForceVector); + if (range.Name is null) + { + int min = range.Min; + ch.CheckUserArg(0 <= min, nameof(range.Min)); + + int max = range.Max; + ch.CheckUserArg(min <= max, nameof(range.Max)); + seg = new Segment(min, max + 1, range.ForceVector); + } + else + { + string columnName = range.Name; + ch.CheckUserArg(columnName != null, nameof(range.Name)); + seg = new Segment(columnName, range.ForceVector); + } segs[i] = seg; } @@ -490,6 +550,7 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent) // ulong: count for key range // int: number of segments // foreach segment: + // string id: name // int: min // int: lim // byte: force vector (verWrittenCur: verIsVectorSupported) @@ -532,11 +593,12 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent) segs = new Segment[cseg]; for (int iseg = 0; iseg < cseg; iseg++) { + string columnName = ctx.LoadStringOrNull(); int min = ctx.Reader.ReadInt32(); int lim = ctx.Reader.ReadInt32(); Contracts.CheckDecode(0 <= min && min < lim); bool forceVector = ctx.Reader.ReadBoolByte(); - segs[iseg] = new Segment(min, lim, forceVector); + segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector); } } @@ -563,6 +625,7 @@ internal void Save(ModelSaveContext ctx) // ulong: count for key range // int: number of segments // foreach segment: + // string id: name // int: min // int: lim // byte: force vector (verWrittenCur: verIsVectorSupported) @@ -588,6 +651,7 @@ internal void Save(ModelSaveContext ctx) ctx.Writer.Write(info.Segments.Length); foreach (var seg in info.Segments) { + ctx.SaveStringOrNull(seg.Name); ctx.Writer.Write(seg.Min); ctx.Writer.Write(seg.Lim); ctx.Writer.WriteBoolByte(seg.ForceVector); diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs index e6b87dfffa..5ea7752f20 100644 --- a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -375,9 +375,17 @@ private ValueGetter> CreateVBufferBooleanGetterDelegate(ColInfo co foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetBoolean(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetBoolean(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetBoolean(columnIndex); } } @@ -397,9 +405,17 @@ private ValueGetter> CreateVBufferByteGetterDelegate(ColInfo colIn foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetByte(columnIndex); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetByte(columnIndex); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetByte(columnIndex); } } @@ -419,9 +435,17 @@ private ValueGetter> CreateVBufferDateTimeGetterDelegate(ColIn foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetDateTime(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetDateTime(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetDateTime(columnIndex); } } @@ -441,9 +465,17 @@ private ValueGetter> CreateVBufferDoubleGetterDelegate(ColInfo c foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? double.NaN : DataReader.GetDouble(columnIndex); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? double.NaN : DataReader.GetDouble(columnIndex); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? double.NaN : DataReader.GetDouble(columnIndex); } } @@ -463,9 +495,17 @@ private ValueGetter> CreateVBufferInt16GetterDelegate(ColInfo col foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt16(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt16(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt16(columnIndex); } } @@ -485,9 +525,17 @@ private ValueGetter> CreateVBufferInt32GetterDelegate(ColInfo colIn foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt32(columnIndex); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt32(columnIndex); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt32(columnIndex); } } @@ -507,9 +555,17 @@ private ValueGetter> CreateVBufferInt64GetterDelegate(ColInfo colI foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt64(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt64(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt64(columnIndex); } } @@ -529,9 +585,17 @@ private ValueGetter> CreateVBufferSByteGetterDelegate(ColInfo col foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (sbyte)DataReader.GetByte(columnIndex); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (sbyte)DataReader.GetByte(columnIndex); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : (sbyte)DataReader.GetByte(columnIndex); } } @@ -551,9 +615,17 @@ private ValueGetter> CreateVBufferSingleGetterDelegate(ColInfo co foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? float.NaN : DataReader.GetFloat(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? float.NaN : DataReader.GetFloat(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? float.NaN : DataReader.GetFloat(columnIndex); } } @@ -573,9 +645,17 @@ private ValueGetter>> CreateVBufferStringGetterDele foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetString(columnIndex).AsMemory(); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetString(columnIndex).AsMemory(); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetString(columnIndex).AsMemory(); } } @@ -595,9 +675,17 @@ private ValueGetter> CreateVBufferUInt16GetterDelegate(ColInfo c foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ushort)DataReader.GetInt16(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ushort)DataReader.GetInt16(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : (ushort)DataReader.GetInt16(columnIndex); } } @@ -617,9 +705,17 @@ private ValueGetter> CreateVBufferUInt32GetterDelegate(ColInfo col foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) + { + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (uint)DataReader.GetInt32(columnIndex); + } + } + else { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (uint)DataReader.GetInt32(columnIndex); + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : (uint)DataReader.GetInt32(columnIndex); } } @@ -639,9 +735,17 @@ private ValueGetter> CreateVBufferUInt64GetterDelegate(ColInfo co foreach (var seg in segs) { - for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + if (seg.Name is null) { - editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); + for (int columnIndex = seg.Min; columnIndex < seg.Lim; columnIndex++, i++) + { + editor.Values[i] = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); + } + } + else + { + var columnIndex = DataReader.GetOrdinal(seg.Name); + editor.Values[i++] = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); } } @@ -653,7 +757,7 @@ private int GetColumnIndex(ColInfo colInfo) { var segs = colInfo.Segments; - if (segs is null) + if ((segs is null) || (segs.Length == 0)) { return DataReader.GetOrdinal(colInfo.Name); } @@ -661,9 +765,16 @@ private int GetColumnIndex(ColInfo colInfo) Contracts.Check(segs.Length == 1); var seg = segs[0]; - Contracts.Check(seg.Min == seg.Lim); - return seg.Min; + if (seg.Name is null) + { + Contracts.Check(seg.Min == seg.Lim); + return seg.Min; + } + else + { + return DataReader.GetOrdinal(seg.Name); + } } } } diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/LoadColumnNameAttribute.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/LoadColumnNameAttribute.cs new file mode 100644 index 0000000000..ac33e204c3 --- /dev/null +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/LoadColumnNameAttribute.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Data +{ + /// + /// Allow member to specify mapping to field(s) in database. + /// To override name of column use . + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] + public sealed class LoadColumnNameAttribute : Attribute + { + /// + /// Maps member to specific field in database. + /// + /// The name of the field in the database. + public LoadColumnNameAttribute(string fieldName) + { + var sources = new List(1); + sources.Add(fieldName); + Sources = sources; + } + + /// + /// Maps member to set of fields in database. + /// + /// Distinct database field names to load as part of this column. + public LoadColumnNameAttribute(params string[] fieldNames) + { + Sources = new List(fieldNames); + } + + [BestFriend] + internal readonly IReadOnlyList Sources; + } +} diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs index e28ebd9813..8ca1af4caf 100644 --- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs +++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs @@ -77,6 +77,52 @@ public void IrisLightGbm() }).PredictedLabel); } + [LightGBMFact] + public void IrisLightGbmWithLoadColumnName() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + + var mlContext = new MLContext(seed: 1); + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"SELECT Label as [My Label], SepalLength, SepalWidth, PetalLength, PetalWidth FROM ""{TestDatasets.irisDb.trainFilename}"""; + + var loader = mlContext.Data.CreateDatabaseLoader(); + + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText); + + var trainingData = loader.Load(databaseSource); + + IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") + .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")) + .Append(mlContext.MulticlassClassification.Trainers.LightGbm()) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + + var model = pipeline.Fit(trainingData); + + var engine = mlContext.Model.CreatePredictionEngine(model); + + Assert.Equal(0, engine.Predict(new IrisData() + { + SepalLength = 4.5f, + SepalWidth = 5.6f, + PetalLength = 0.5f, + PetalWidth = 0.5f, + }).PredictedLabel); + + Assert.Equal(1, engine.Predict(new IrisData() + { + SepalLength = 4.9f, + SepalWidth = 2.4f, + PetalLength = 3.3f, + PetalWidth = 1.0f, + }).PredictedLabel); + } + [LightGBMFact] public void IrisVectorLightGbm() { @@ -119,6 +165,48 @@ public void IrisVectorLightGbm() }).PredictedLabel); } + [LightGBMFact] + public void IrisVectorLightGbmWithLoadColumnName() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + + var mlContext = new MLContext(seed: 1); + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}"""; + + var loader = mlContext.Data.CreateDatabaseLoader(); + + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText); + + var trainingData = loader.Load(databaseSource); + + IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") + .Append(mlContext.Transforms.Concatenate("Features", "SepalInfo", "PetalInfo")) + .Append(mlContext.MulticlassClassification.Trainers.LightGbm()) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + + var model = pipeline.Fit(trainingData); + + var engine = mlContext.Model.CreatePredictionEngine(model); + + Assert.Equal(0, engine.Predict(new IrisVectorData() + { + SepalInfo = new float[] { 4.5f, 5.6f }, + PetalInfo = new float[] { 0.5f, 0.5f }, + }).PredictedLabel); + + Assert.Equal(1, engine.Predict(new IrisVectorData() + { + SepalInfo = new float[] { 4.9f, 2.4f }, + PetalInfo = new float[] { 3.3f, 1.0f }, + }).PredictedLabel); + } + [Fact] public void IrisSdcaMaximumEntropy() { @@ -189,6 +277,21 @@ public class IrisData public float PetalWidth; } + public class IrisDataWithLoadColumnName + { + [LoadColumnName("My Label")] + [ColumnName("Label")] + public int Kind; + + public float SepalLength; + + public float SepalWidth; + + public float PetalLength; + + public float PetalWidth; + } + public class IrisVectorData { public int Label; @@ -202,6 +305,19 @@ public class IrisVectorData public float[] PetalInfo; } + public class IrisVectorDataWithLoadColumnName + { + public int Label; + + [LoadColumnName("SepalLength", "SepalWidth")] + [VectorType(2)] + public float[] SepalInfo; + + [LoadColumnName("PetalLength", "PetalWidth")] + [VectorType(2)] + public float[] PetalInfo; + } + public class IrisPrediction { public int PredictedLabel;