diff --git a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs index 19a5d3f44c..6810fa3959 100644 --- a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs +++ b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs @@ -441,8 +441,16 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) { var cols = new List(); - if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value)) - cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); + if (labelColumn != null && labelColumn.Value.IsKey) + { + if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) && + metaCol.Kind == SchemaShape.Column.VectorKind.Vector) + { + if (metaCol.ItemType is TextDataViewType) + cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); + cols.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector, metaCol.ItemType, false)); + } + } cols.AddRange(GetTrainerOutputAnnotation()); return cols; } diff --git a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs index ae0bd506dc..1cb1ffc045 100644 --- a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs @@ -390,7 +390,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun if (trainSchema?.Label == null) return mapper; // We don't even have a label identified in a training schema. 
var keyType = trainSchema.Label.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType; - if (keyType == null || !CanWrap(mapper, keyType)) + if (keyType == null) return mapper; // Great!! All checks pass. @@ -409,11 +409,19 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun /// from the model of a bindable mapper) /// Whether we can call with /// this mapper and expect it to succeed - internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameType) + internal static bool CanWrapTrainingLabels(ISchemaBoundMapper mapper, DataViewType labelNameType) + { + if (GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.TrainingLabelValues, out var scoreType)) + // Check that the type is vector, and is of compatible size with the score output. + return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize(); + return false; + } + + internal static bool GetTypesForWrapping(ISchemaBoundMapper mapper, DataViewType labelNameType, string metaKind, out DataViewType scoreType) { Contracts.AssertValue(mapper); Contracts.AssertValue(labelNameType); - + scoreType = null; ISchemaBoundRowMapper rowMapper = mapper as ISchemaBoundRowMapper; if (rowMapper == null) return false; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so. @@ -423,12 +431,30 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy var scoreCol = outSchema.GetColumnOrNull(AnnotationUtils.Const.ScoreValueKind.Score); if (!outSchema.TryGetColumnIndex(AnnotationUtils.Const.ScoreValueKind.Score, out scoreIdx)) return false; // The mapper doesn't even publish a score column to attach the metadata to. - if (outSchema[scoreIdx].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type != null) - return false; // The mapper publishes a score column, and already produces its own slot names. 
- var scoreType = outSchema[scoreIdx].Type; + if (outSchema[scoreIdx].Annotations.Schema.GetColumnOrNull(metaKind)?.Type != null) + return false; // The mapper publishes a score column, and already produces its own metakind. + scoreType = outSchema[scoreIdx].Type; + return true; + } - // Check that the type is vector, and is of compatible size with the score output. - return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance; + /// <summary> + /// This is a utility method used to determine whether <see cref="LabelNameBindableMapper"/> + /// can or should be used to wrap <paramref name="mapper"/>. This will not throw, since the + /// desired behavior in the event that it cannot be wrapped, is to just back off to the original + /// "unwrapped" bound mapper. + /// </summary> + /// <param name="mapper">The mapper we are seeing if we can wrap</param> + /// <param name="labelNameType">The type of the label names from the metadata (either + /// originating from the key value metadata of the training label column, or deserialized + /// from the model of a bindable mapper)</param> + /// <returns>Whether we can call <c>LabelNameBindableMapper.CreateBound</c> with + /// this mapper and expect it to succeed</returns> + internal static bool CanWrapSlotNames(ISchemaBoundMapper mapper, DataViewType labelNameType) + { + if (GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.SlotNames, out var scoreType)) + // Check that the type is vector, and is of compatible size with the score output. 
+ return labelNameType is VectorDataViewType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance; + return false; } internal static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) @@ -449,8 +475,12 @@ internal static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoun { trainSchema.Label.Value.GetKeyValues(ref value); }; - - return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.SlotNames, CanWrap); + var resultMapper = mapper; + if (CanWrapTrainingLabels(resultMapper, type)) + resultMapper = LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)resultMapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.TrainingLabelValues, CanWrapTrainingLabels); + if (CanWrapSlotNames(resultMapper, type)) + resultMapper = LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)resultMapper, type as VectorDataViewType, getter, AnnotationUtils.Kinds.SlotNames, CanWrapSlotNames); + return resultMapper; } [BestFriend] diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index e87e93f43c..90ec8b0357 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -62,22 +62,12 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string { var scoreColMetadata = mapper.OutputSchema[scoreColIndex].Annotations; - var slotColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames); - if (slotColumn?.Type is VectorDataViewType slotColVecType && (ulong)slotColVecType.Size == predColKeyType.Count) + var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.TrainingLabelValues); + if (trainLabelColumn?.Type is VectorDataViewType 
trainLabelColVecType && (ulong)trainLabelColVecType.Size == predColKeyType.Count) { - Contracts.Assert(slotColVecType.Size > 0); - _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, slotColVecType.RawType, - scoreColMetadata, slotColumn.Value); - } - else - { - var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.TrainingLabelValues); - if (trainLabelColumn?.Type is VectorDataViewType trainLabelColVecType && (ulong)trainLabelColVecType.Size == predColKeyType.Count) - { - Contracts.Assert(trainLabelColVecType.Size > 0); - _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, trainLabelColVecType.RawType, - scoreColMetadata, trainLabelColumn.Value); - } + Contracts.Assert(trainLabelColVecType.Size > 0); + _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, trainLabelColVecType.RawType, + scoreColMetadata, trainLabelColumn.Value); } } } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs index d82c306ec5..fff8663643 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs @@ -131,7 +131,8 @@ private protected override NaiveBayesMulticlassModelParameters TrainModelCore(Tr int size = cursor.Label + 1; Utils.EnsureSize(ref labelHistogram, size); Utils.EnsureSize(ref featureHistogram, size); - Utils.EnsureSize(ref featureHistogram[cursor.Label], featureCount); + if (featureHistogram[cursor.Label] == null) + featureHistogram[cursor.Label] = new int[featureCount]; labelHistogram[cursor.Label] += 1; labelCount = labelCount < size ? 
size : labelCount; diff --git a/test/Microsoft.ML.Functional.Tests/Training.cs b/test/Microsoft.ML.Functional.Tests/Training.cs index 2499ecc634..165c57dc20 100644 --- a/test/Microsoft.ML.Functional.Tests/Training.cs +++ b/test/Microsoft.ML.Functional.Tests/Training.cs @@ -438,7 +438,7 @@ public void ContinueTrainingSymbolicStochasticGradientDescent() } /// - /// Training: Meta-compononts function as expected. For OVA (one-versus-all), a user will be able to specify only + /// Training: Meta-components function as expected. For OVA (one-versus-all), a user will be able to specify only /// binary classifier trainers. If they specify a different model class there should be a compile error. /// [Fact] @@ -467,5 +467,39 @@ public void MetacomponentsFunctionAsExpectedOva() // Evaluate the model. var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions); } + + /// <summary> + /// Training: Meta-components function as expected with key handling. An OVA (one-versus-all) trainer over a binary + /// classifier is fit on a label mapped to a key, and the predicted label is then mapped back from key to value. + /// </summary> + [Fact] + public void MetacomponentsFunctionWithKeyHandling() + { + var mlContext = new MLContext(seed: 1); + + var data = mlContext.Data.LoadFromTextFile(GetDataPath(TestDatasets.iris.trainFilename), + hasHeader: TestDatasets.iris.fileHasHeader, + separatorChar: TestDatasets.iris.fileSeparator); + + // Create a model training an OVA trainer with a binary classifier. 
+ var binaryClassificationTrainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression( + new LbfgsLogisticRegressionBinaryTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1, }); + var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) + .Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryClassificationTrainer)) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + + // Fit the binary classification pipeline. + var binaryClassificationModel = binaryClassificationPipeline.Fit(data); + + // Transform the data + var binaryClassificationPredictions = binaryClassificationModel.Transform(data); + + // Evaluate the model. + var binaryClassificationMetrics = mlContext.MulticlassClassification.Evaluate(binaryClassificationPredictions); + + Assert.Equal(0.4367, binaryClassificationMetrics.LogLoss, 4); + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs index cfbdd079ed..30fdafe90b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs @@ -37,22 +37,22 @@ void PredictAndMetadata() var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',', hasHeader: true); var testData = ml.Data.CreateEnumerable(testLoader, false); - + // During prediction we will get Score column with 3 float values. // We need to find way to map each score to original label. - // In order to do what we need to get SlotNames from Score column. - // Slot names on top of Score column represent original labels for i-th value in Score array. 
- VBuffer> slotNames = default; - engine.OutputSchema[nameof(IrisPrediction.Score)].GetSlotNames(ref slotNames); + // In order to do that, we need to get TrainingLabelValues from the Score column. + // TrainingLabelValues on top of the Score column represent the original label for the i-th value in the Score array. + VBuffer> originalLabels = default; + engine.OutputSchema[nameof(IrisPrediction.Score)].Annotations.GetValue(AnnotationUtils.Kinds.TrainingLabelValues, ref originalLabels); // Since we apply MapValueToKey estimator with default parameters, key values // depends on order of occurence in data file. Which is "Iris-setosa", "Iris-versicolor", "Iris-virginica" // So if we have Score column equal to [0.2, 0.3, 0.5] that's mean what score for // Iris-setosa is 0.2 // Iris-versicolor is 0.3 // Iris-virginica is 0.5. - Assert.True(slotNames.GetItemOrDefault(0).ToString() == "Iris-setosa"); - Assert.True(slotNames.GetItemOrDefault(1).ToString() == "Iris-versicolor"); - Assert.True(slotNames.GetItemOrDefault(2).ToString() == "Iris-virginica"); + Assert.Equal("Iris-setosa", originalLabels.GetItemOrDefault(0).ToString()); + Assert.Equal("Iris-versicolor", originalLabels.GetItemOrDefault(1).ToString()); + Assert.Equal("Iris-virginica", originalLabels.GetItemOrDefault(2).ToString()); // Let's look how we can convert key value for PredictedLabel to original labels. // We need to read KeyValues for "PredictedLabel" column. diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 0ec033b4c5..9e1f3615e5 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System.Linq; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.RunTests; @@ -83,12 +84,14 @@ public void MetacomponentsFeaturesRenamed() var data = loader.Load(GetDataPath(TestDatasets.irisData.trainFilename)); var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated( - new SdcaNonCalibratedBinaryTrainer.Options { + new SdcaNonCalibratedBinaryTrainer.Options + { LabelColumnName = "Label", FeatureColumnName = "Vars", MaximumNumberOfIterations = 100, Shuffle = true, - NumberOfThreads = 1, }); + NumberOfThreads = 1, + }); var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest)